
[JAX] MXFP8 Grouped Quant+GEMM#2763

Open
jberchtold-nvidia wants to merge 71 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/gmm-mxfp8

Conversation

Collaborator

@jberchtold-nvidia jberchtold-nvidia commented Mar 14, 2026

Description

TE/JAX integration of the V2 MXFP8 grouped quantization kernel and the V2 MXFP8 grouped GEMM, both of which are CUDA-graph-safe.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add new primitive and FFI for V2 grouped quantize that currently only supports MXFP8
  • Extend V2 grouped GEMM to support MXFP8
  • For both V1 and V2, move swizzling from grouped GEMM FFI to grouped quantize FFI. This is required because currently V2 can only do swizzling when fused with quantization; an independent swizzle kernel that supports ragged groups is not available.
    • This entails updating the dequantization logic and the Q->DQ tests to support pre-swizzled scales.
  • Added some small kernels to TE common to handle int32 -> int64 conversion and offset calculations, needed because JAX defaults to int32 dtypes
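JAX traces with int32 by default (64-bit types require the jax_enable_x64 flag), while the grouped-GEMM setup consumes int64 offsets, so group metadata must be widened on device. A host-side NumPy sketch of the conversion, with hypothetical names loosely mirroring the TE common kernel:

```python
import numpy as np

def convert_int32_to_int64_with_multiplier(group_sizes_i32, multiplier):
    # Widen int32 group sizes to int64 and scale them (e.g. by the
    # trailing dimension) so downstream offset math cannot overflow int32.
    return group_sizes_i32.astype(np.int64) * np.int64(multiplier)

sizes = np.array([4096, 8192], dtype=np.int32)
# 8192 * 1_048_576 overflows int32 (max ~2.1e9) but fits comfortably in int64.
print(convert_int32_to_int64_with_multiplier(sizes, 1_048_576))
```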

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

jberchtold-nvidia and others added 24 commits March 9, 2026 15:42
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft March 14, 2026 17:25
Contributor

greptile-apps bot commented Mar 14, 2026

Greptile Summary

This PR integrates the V2 MXFP8 grouped quantization kernel (nvte_group_quantize) and extends V2 grouped GEMM to support MXFP8 in the JAX backend. Both paths are CUDA-graph-safe. A key architectural change moves scale_inv swizzling from the grouped GEMM FFI into the grouped quantize FFI for both V1 and V2, requiring updates to the dequantizer to unswizzle for testing. Previous review concerns (NameError on lhs_first_dims, AttributeError on None.size) appear to be addressed in the current version.

Confidence Score: 5/5

PR is safe to merge; all remaining findings are minor P2 style issues with no runtime impact.

Previously-flagged P1 issues (NameError on lhs_first_dims/lhs_last_dims, AttributeError on None.size) are resolved in the current version. The three new findings are all P2: a missing space in an error message string, an uninitialised-but-unused output buffer in the V2 quantize C++ handler, and counterintuitive variable naming in the dequantizer unswizzle helper. None of these affect runtime correctness.

The files with remaining findings are transformer_engine/jax/quantize/dequantizer.py (colwise unswizzle naming) and transformer_engine/jax/csrc/extensions/quantization.cpp (updated_amaxs not written).

Important Files Changed

  • transformer_engine/jax/csrc/extensions/quantization.cpp: Adds GroupedQuantizeV2FFI for CUDA-graph-safe MXFP8 quantization; V1 path gains pre-swizzled scale_inv output. The updated_amaxs Result_Type buffer is accepted but never written to in the V2 handler.
  • transformer_engine/jax/cpp_extensions/gemm.py: Refactors grouped_gemm into helpers (_quantize_inputs_if_needed, _get_num_gemms, _adjust_contracting_dims_for_hopper_fp8_transpose); wires V2 GEMM for MXFP8; previously-flagged NameError and AttributeError issues are resolved.
  • transformer_engine/jax/cpp_extensions/quantization.py: Adds GroupedQuantizePrimitive with V2 path selection via _use_v2_kernel; introduces int64_workspace abstract for CUDA-graph-safe offset computation. Previously-flagged assert False pattern is addressed.
  • transformer_engine/jax/quantize/dequantizer.py: Adds _unswizzle_mxfp8_grouped_scale to invert the GEMM-swizzled layout for both V1 and V2; dequantizer now flattens to 2D before calling _dequantize_func. Colwise branch uses counterintuitive variable naming (cols, rows = padded_scale_2d where the first element is M//32).
  • transformer_engine/jax/quantize/tensor.py: Adds pre_swizzled field to GroupedScaledTensor1x and threads it through ScaledTensorFactory; adds group_sizes property. pre_swizzled is static metadata (in aux_data), so pytree structure updates correctly.
  • transformer_engine/jax/csrc/extensions/gemm.cpp: Extends make_grouped_tensor with MXFP8/colwise overloads; V2 GEMM now supports MXFP8 by consuming pre-swizzled scale_inv directly; removes the old per-GEMM swizzle loop from the V1 path.
  • transformer_engine/jax/flax/module.py: Lifts the unconditional ValueError for quantized grouped GEMM; now allows MXFP8BlockScaling and threads quantization_checkpoint_name through wrap_function_in_te_state_module.
  • tests/jax/test_custom_call_compute.py: Extends grouped quantize and grouped dense tests with V2-eligible shapes and group_size_multiplier parametrization; adds a skip guard for the V2 kernel with non-128-aligned group sizes.

Sequence Diagram

sequenceDiagram
    participant PY as Python (grouped_gemm)
    participant GQv1 as GroupedQuantizeFFI (V1)
    participant GQv2 as GroupedQuantizeV2FFI (V2)
    participant GGv1 as GroupedGemmFFI (V1)
    participant GGv2 as GroupedGemmV2FFI (V2)

    PY->>PY: _use_v2_kernel? (SM100+, shape aligned)

    alt V1 path (SM<100 or shape unaligned)
        PY->>GQv1: x, scale, group_sizes
        GQv1->>GQv1: nvte_quantize per group
        GQv1->>GQv1: set_with_gemm_swizzled_scales(true)
        GQv1-->>PY: pre-swizzled scale_inv
        PY->>GGv1: lhs, rhs, pre-swizzled sinv
        GGv1->>GGv1: GEMM (no re-swizzle needed)
        GGv1-->>PY: output
    else V2 path (SM100+, 128-aligned shapes)
        PY->>GQv2: x, group_sizes, int64_workspace
        GQv2->>GQv2: nvte_convert_int32_to_int64_with_multiplier
        GQv2->>GQv2: nvte_compute_grouped_tensor_offsets
        GQv2->>GQv2: nvte_group_quantize (fused swizzle)
        GQv2-->>PY: pre-swizzled scale_inv + int64_workspace
        PY->>GGv2: lhs, rhs, pre-swizzled sinv, alpha/beta
        GGv2->>GGv2: MXFP8 grouped GEMM (CUDA-graph safe)
        GGv2-->>PY: output
    end


Comment on lines +1028 to +1031
assert False, (
    "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
    " scaling_mode {}".format(scaling_mode)
)
Contributor

assert False makes fallback unreachable

The assert False statements at lines 1028, 1036, and 1045 will always raise AssertionError before the return False on the next line, making those returns dead code. More critically, if Python is run with optimizations enabled (-O flag, which disables asserts), the assert False becomes a no-op and execution falls through — the function would silently skip the validation and continue to later checks or return True, potentially routing data to the V2 kernel under unsupported conditions.

These should be changed to raise an explicit exception or simply return False (if fallback to V1 is the intended behavior) without using assert:

Suggested change
- assert False, (
-     "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
-     " scaling_mode {}".format(scaling_mode)
- )
+ return False

This same pattern repeats at lines 1036-1039 and 1044-1048.
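A standalone demonstration (not TE code) of the hazard described above: under python -O, asserts are stripped, so an assert False used for control flow silently falls through.

```python
import subprocess
import sys

snippet = (
    "ok = False\n"
    "try:\n"
    "    assert False, 'unsupported scaling mode'\n"
    "except AssertionError:\n"
    "    ok = True\n"
    "print('raised' if ok else 'skipped')\n"
)

# Default interpreter: the assert fires and is caught.
normal = subprocess.run([sys.executable, "-c", snippet],
                        capture_output=True, text=True)
# With -O, __debug__ is False and assert statements are removed entirely.
optimized = subprocess.run([sys.executable, "-O", "-c", snippet],
                           capture_output=True, text=True)

print(normal.stdout.strip())     # raised
print(optimized.stdout.strip())  # skipped
```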

Comment on lines +1078 to +1085
cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
}
// size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
// if (!is_rhs_ragged) {
// NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
Contributor

Commented-out group_sizes sum validation

The validation that sum(group_sizes) matches m (or k for wgrad) has been commented out entirely. While the new *_first_dims/*_last_dims interface changes how dimensions are communicated, removing this runtime sanity check eliminates a useful guard against dimension mismatches that could lead to silent data corruption or out-of-bounds memory access. Consider either adapting this validation to work with the new interface or adding an equivalent check.
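If the check is reinstated, it could look like this host-side sketch (Python with hypothetical names; the real check would live in the C++ FFI):

```python
import numpy as np

def check_group_sizes(group_sizes, m, is_rhs_ragged=False):
    # For a non-ragged RHS, the ragged LHS group rows must sum to M.
    total = int(np.asarray(group_sizes, dtype=np.int64).sum())
    if not is_rhs_ragged and total != m:
        raise ValueError(
            f"Unexpected group_sizes! M = {m}, sum(group_sizes) = {total}"
        )

check_group_sizes([128, 256, 128], m=512)  # OK: 128 + 256 + 128 == 512
try:
    check_group_sizes([128, 256, 128], m=640)
except ValueError as e:
    print(e)
```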

jberchtold-nvidia and others added 6 commits April 6, 2026 16:30
@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia jberchtold-nvidia changed the title from "[JAX] MXFP8 Grouped GEMM" to "[JAX] MXFP8 Grouped Quant+GEMM" on Apr 8, 2026
@tdophung tdophung marked this pull request as ready for review April 8, 2026 21:58
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]

is_v2_grouped_gemm_supported = get_device_compute_capability(0) >= 100
v2_grouped_gemm_unsupported_reason = "V2 grouped GEMM requires SM100+ (Blackwell or newer)"
Collaborator

maybe we should wrap this into utils somewhere, and reuse it to guard all calls to V2 grouped GEMM, not just from test_custom_call_compute

Collaborator

probably would also make the def grouped_gemm in gemm.py shorter?

Collaborator Author

Good idea. I've decided to simplify the test code to make it less coupled to V1/V2. I still have some comments indicating which test cases should trigger V1/V2, but there is less V1/V2 logic in the tests themselves; it is left as more of an internal implementation detail.

Collaborator Author

Separately, I've also simplified the grouped_gemm function as I agree that function body was too complex. It is now refactored into several helper functions. It could be cleaned up further, but it's at least better than it was previously. Thanks!


# *32 so that the input shapes work for MXFP8
input_shape = (m * 32, n)
# Use 128 multiplier for V2-eligible MXFP8 shapes (both M and K 128-aligned)
Collaborator

make it clearer that the 128-alignment is a cuBLASLt requirement, while the 32 multiplier comes from MXFP8 scaling being applied to chunks of 32 elements

Collaborator Author

This isn't solely due to cuBLASLt. The grouped quantize kernel also has these alignment requirements. I've refactored this test code to be less coupled to the internal V1/V2 logic and instead selected a handful of test cases that cover both V1 and V2; whether V1 or V2 is selected is more of an implementation detail than something visible at the test level (aside from some small notes next to the configs showing that both V1 and V2 are covered).

return False
# V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both
# operands is a multiple of 128. The V2 GEMM setup kernel computes per-group
# scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``.
Collaborator

this took me a bit to understand, not sure if you should clarify what K_blocks is as it is not defined in this file. If after 2nd read and it still feels pretty trivial then feel free to SR

Collaborator Author

I've reworded this so it's clearer. I believe we could support cases where this dim is not divisible by 128; there is no inherent limitation in the GEMM afaik. But currently, for simplicity, the grouped quantize and grouped GEMM setup kernels only handle these offsets correctly when this dim is divisible by 128.
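The offset arithmetic being discussed can be sketched host-side (NumPy; names hypothetical). With one E8M0 scale per 32 contiguous elements, a group's scale offset is derived from its data offset; the kernels require 128-aligned trailing dims so the derived offsets line up with the swizzled scale layout:

```python
import numpy as np

SCALE_BLOCK = 32  # MXFP8: one E8M0 scale per 32 contiguous elements

def derived_scale_offsets(group_rows, last_dim):
    # Exclusive prefix sum of per-group element counts gives data offsets;
    # dividing by 32 yields the per-group scale offsets the setup kernel
    # computes (data_offset / 32 == K_blocks * last_dim per group).
    rows = np.asarray(group_rows, dtype=np.int64)
    data_offsets = np.concatenate([[0], np.cumsum(rows * last_dim)])
    return data_offsets // SCALE_BLOCK

print(derived_scale_offsets([128, 256], last_dim=256))  # offsets 0, 1024, 3072
```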

# [n_groups int64 group_sizes | n_groups+1 int64 offsets]
# = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8.
n_groups = group_sizes_aval.size
fifth_out_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8)
Collaborator

fifth output seems like a bad name for this. Maybe group_sizes_and_offsets?

Collaborator

oh I see that it is updated_amax for V1. Not sure what the best name would be here, given that it serves different purposes in the two versions

Collaborator Author

Good point, this is a bad name. Instead of this overloaded 5th output, I've made both FFIs use 6 outputs and left the workspace empty on V1 for consistency. For V2, if we ever want to support delayed scaling, we would need this updated amax output anyway.
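The workspace layout from the excerpt above, as a small self-contained sketch (Python):

```python
import numpy as np

def int64_workspace_nbytes(n_groups):
    # Layout: [n_groups int64 group_sizes | n_groups + 1 int64 offsets],
    # stored as a flat uint8 buffer: (2 * n_groups + 1) * sizeof(int64).
    return (2 * n_groups + 1) * np.dtype(np.int64).itemsize

workspace = np.zeros(int64_workspace_nbytes(4), dtype=np.uint8)
print(workspace.nbytes)  # 72 bytes for 4 groups
```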

if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING:
return False
# Require SM100+ so V2 quantize (fused swizzle) is only used alongside V2 GEMM.
if get_min_device_compute_capability() < 100:
Collaborator

in gemm.py, you check get_device_compute_capability, but here it is get_min_device_compute_capability. These would be okay if all GPUs on the system have the same compute capability (which is true for most of our products, maybe minus Galaxy ones, I don't remember clearly). But for consistency, please use the same one.

Collaborator

Same for the test file too

Collaborator Author

Good catch, thanks! I've updated the changes in this PR to use the min device compute capability.


// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim).
// Produces n_groups+1 values. Single-threaded sequential scan; n_groups is typically small.
__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets,
Collaborator

I have an idea for this in case n_groups ever gets large: do a 32-thread cumsum per block, then use warp shuffles to reduce the local sums into one sum.

Collaborator Author

Sounds good! Currently the kernel runtime is pretty small relative to our other kernels, and our n_groups per device is fairly small with EP, but it's a good idea for the future if n_groups per device gets bigger.
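A host-side reference for the kernel above (Python); the GPU version is a single-threaded sequential scan, which this mirrors with a cumulative sum:

```python
import numpy as np

def compute_grouped_tensor_offsets(first_dims, last_dim):
    # offsets[0] = 0, offsets[i] = sum(first_dims[0..i-1]) * last_dim,
    # producing n_groups + 1 exclusive-prefix-sum values.
    first_dims = np.asarray(first_dims, dtype=np.int64)
    offsets = np.zeros(first_dims.size + 1, dtype=np.int64)
    offsets[1:] = np.cumsum(first_dims * last_dim)
    return offsets

print(compute_grouped_tensor_offsets([128, 256, 128], last_dim=512).tolist())
# [0, 65536, 196608, 262144]
```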

@jberchtold-nvidia
Collaborator Author

/te-ci

