
Add grouped unswizzle functionality for MXFP8 scaling factors#2837

Open
int-smart wants to merge 2 commits into NVIDIA:main from int-smart:feature/groupedUnswizzle

Conversation

@int-smart
Contributor

Description

Added grouped unswizzle functionality for MXFP8 scaling factors, mirroring the existing grouped swizzle implementation, and added tests covering it.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added grouped unswizzle APIs and implementation in transformer_engine/common/swizzle/swizzle.cu and declarations in transformer_engine/common/include/transformer_engine/swizzle.h
  • Added/extended tests in tests/cpp/operator/test_swizzle.cu, including standalone grouped unswizzle and grouped swizzle→unswizzle round-trip coverage
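Based on the description above, the new public entry point presumably mirrors the existing grouped swizzle declaration. A sketch of what the added declaration in swizzle.h might look like (the exact parameter types and doc wording are assumptions, not taken from the diff):

```cpp
/*! \brief Inverse of nvte_swizzle_grouped_scaling_factors: restore MXFP8
 *         scaling factors from the swizzled GEMM layout back to the compact
 *         layout. (Hypothetical sketch; the signature is assumed to mirror
 *         the grouped swizzle API.)
 */
void nvte_unswizzle_grouped_scaling_factors(const NVTEGroupedTensor input,
                                            NVTEGroupedTensor output,
                                            cudaStream_t stream);
```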

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
@greptile-apps
Contributor

greptile-apps bot commented Apr 5, 2026

Greptile Summary

This PR adds a grouped unswizzle operation for MXFP8 scaling factors, providing the inverse of the existing nvte_swizzle_grouped_scaling_factors. The implementation closely mirrors the swizzle path: new device kernels (grouped_unswizzle_row/col_scaling_uniform_shape_kernel) delegate to the pre-existing single-tensor unswizzle_row/col_scaling_kernel_impl, and the host function unswizzle_grouped_scaling_factors follows the same structure as its swizzle counterpart. A public C API entry point is declared in swizzle.h. Tests cover nine shape combinations (aligned, M-padding only, K-padding only, both) for both standalone unswizzle correctness (compared against a CPU reference) and a full swizzle→unswizzle round-trip.

Key concerns:

  • Missing input data guard (logic bug): unswizzle_grouped_scaling_factors derives has_rowwise_scale_inv/has_columnwise_scale_inv from the output tensor, but never checks that the corresponding input buffers have data. Passing a null input->scale_inv.dptr to the CUDA kernel will cause a memory fault. Every other unswizzle variant in the file guards against this; the same guard should be added here.
  • Input shape uniformity not validated: Only output->all_same_shape() is checked; input->all_same_shape() is not, unlike the swizzle path which validates its driving tensor's shape uniformity.

Confidence Score: 3/5

Not safe to merge as-is — missing input data availability check can cause a null-pointer dereference in the CUDA kernel at runtime.

The overall structure mirrors the well-tested swizzle path and test coverage is thorough across shape combinations, but the absence of the input->scale_inv.has_data() guard (present in every other unswizzle variant in this file) is a concrete logic bug that can crash at runtime when a caller provides an output tensor with scale data but a mismatched or empty input tensor.

transformer_engine/common/swizzle/swizzle.cu — specifically unswizzle_grouped_scaling_factors around lines 1735–1743.

Important Files Changed

  • transformer_engine/common/swizzle/swizzle.cu — Adds grouped unswizzle kernels and host function; missing input has_data() validation can cause a null-pointer dereference in the CUDA kernel.
  • transformer_engine/common/include/transformer_engine/swizzle.h — Adds a correct public C API declaration for nvte_unswizzle_grouped_scaling_factors, mirroring the swizzle counterpart with well-written doc comments.
  • tests/cpp/operator/test_swizzle.cu — Adds a standalone grouped unswizzle test and a swizzle→unswizzle round-trip test with good shape coverage across aligned, M-padded, K-padded, and both-padded cases.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant API as nvte_unswizzle_grouped_scaling_factors
    participant Host as unswizzle_grouped_scaling_factors
    participant KernelRow as grouped_unswizzle_row_scaling_kernel
    participant KernelCol as grouped_unswizzle_col_scaling_kernel
    participant Impl as unswizzle_row/col_scaling_kernel_impl

    Caller->>API: input (swizzled GroupedTensor), output (compact GroupedTensor), stream
    API->>Host: convertNVTEGroupedTensorCheck(input/output)
    Host->>Host: validate scaling_mode, swizzle flags, shape uniformity
    Host->>Host: compute padded_m, padded_k, vec_load_size, slm_size
    alt has_rowwise_scale_inv
        Host->>KernelRow: launch(input_ptr, output_ptr, padded_m, padded_k, stride_bytes)
        KernelRow->>Impl: unswizzle_row_scaling_kernel_impl(...)
        Impl-->>KernelRow: unswizzled rowwise scale values
    end
    alt has_columnwise_scale_inv
        Host->>KernelCol: launch(input_ptr, output_ptr, padded_m, padded_k, stride_bytes)
        KernelCol->>Impl: unswizzle_col_scaling_kernel_impl(...)
        Impl-->>KernelCol: unswizzled columnwise scale values
    end
    Host-->>Caller: async result on CUDA stream

Reviews (1): Last reviewed commit: "Merge branch 'main' into feature/grouped..."

Comment on lines +1735 to +1739
const bool has_rowwise_scale_inv = output->scale_inv.has_data();
const bool has_columnwise_scale_inv = output->columnwise_scale_inv.has_data();
if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
return;
}
Contributor


P1 Missing input data availability check

The function sets has_rowwise_scale_inv/has_columnwise_scale_inv from the output tensor's buffers, but never verifies that the corresponding input buffers also contain data. When has_rowwise_scale_inv is true but input->scale_inv.dptr is null, the kernel at line 1791 receives a null input pointer and will either produce garbage or trigger a CUDA memory fault.

Every other unswizzle variant in this file has this guard. For example, the non-grouped unswizzle_scaling_factors (lines 1220–1228):

if (has_rowwise_scale_inv) {
    NVTE_CHECK(input->scale_inv.has_data(),
               "Output tensor requests row-wise scaling factors, but input "
               "tensor does not provide them.");
} else if (has_columnwise_scale_inv) {
    NVTE_CHECK(input->columnwise_scale_inv.has_data(),
               "Output tensor requests column-wise scaling factors, but input "
               "tensor does not provide them.");
}

The same check should be inserted right after the early-return on line 1739.

Comment on lines +1741 to +1743
NVTE_CHECK(output->all_same_shape(), "Grouped unswizzle requires uniform tensor shapes.");
NVTE_CHECK(output->all_same_last_dim() && output->all_same_first_dim(),
"Grouped unswizzle requires uniform tensor shapes.");
Contributor


P2 Input shape uniformity not validated

Only the output tensor is checked for uniform shapes. The swizzle counterpart (swizzle_grouped_scaling_factors, line 1587) validates input->all_same_shape() on the driving tensor. For symmetric defensive coverage, an equivalent input->all_same_shape() check should be added here, since the kernel strides for the input buffer are computed from output-derived padded_m/padded_k. A mismatched input shape may not be caught by the downstream numel() checks if the total element counts happen to coincide.

Suggested change

Before:
    NVTE_CHECK(output->all_same_shape(), "Grouped unswizzle requires uniform tensor shapes.");
    NVTE_CHECK(output->all_same_last_dim() && output->all_same_first_dim(),
               "Grouped unswizzle requires uniform tensor shapes.");

After:
    NVTE_CHECK(output->all_same_shape(), "Grouped unswizzle requires uniform tensor shapes.");
    NVTE_CHECK(output->all_same_last_dim() && output->all_same_first_dim(),
               "Grouped unswizzle requires uniform tensor shapes.");
    NVTE_CHECK(input->all_same_shape(), "Grouped unswizzle requires uniform input tensor shapes.");
