[JAX] MXFP8 Grouped Quant+GEMM#2763
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Summary

This PR integrates the V2 MXFP8 grouped quantization kernel (…).

Confidence Score: 5/5. PR is safe to merge; all remaining findings are minor P2 style issues with no runtime impact. Previously-flagged P1 issues (NameError on lhs_first_dims/lhs_last_dims, AttributeError on None.size) are resolved in the current version. The three new findings are all P2: a missing space in an error-message string, an uninitialised-but-unused output buffer in the V2 quantize C++ handler, and counterintuitive variable naming in the dequantizer unswizzle helper. None of these affect runtime correctness.

Important files changed: transformer_engine/jax/quantize/dequantizer.py (colwise unswizzle naming) and transformer_engine/jax/csrc/extensions/quantization.cpp (updated_amaxs not written).
Sequence Diagram

```mermaid
sequenceDiagram
    participant PY as Python (grouped_gemm)
    participant GQv1 as GroupedQuantizeFFI (V1)
    participant GQv2 as GroupedQuantizeV2FFI (V2)
    participant GGv1 as GroupedGemmFFI (V1)
    participant GGv2 as GroupedGemmV2FFI (V2)
    PY->>PY: _use_v2_kernel? (SM100+, shape aligned)
    alt V1 path (SM<100 or shape unaligned)
        PY->>GQv1: x, scale, group_sizes
        GQv1->>GQv1: nvte_quantize per group
        GQv1->>GQv1: set_with_gemm_swizzled_scales(true)
        GQv1-->>PY: pre-swizzled scale_inv
        PY->>GGv1: lhs, rhs, pre-swizzled sinv
        GGv1->>GGv1: GEMM (no re-swizzle needed)
        GGv1-->>PY: output
    else V2 path (SM100+, 128-aligned shapes)
        PY->>GQv2: x, group_sizes, int64_workspace
        GQv2->>GQv2: nvte_convert_int32_to_int64_with_multiplier
        GQv2->>GQv2: nvte_compute_grouped_tensor_offsets
        GQv2->>GQv2: nvte_group_quantize (fused swizzle)
        GQv2-->>PY: pre-swizzled scale_inv + int64_workspace
        PY->>GGv2: lhs, rhs, pre-swizzled sinv, alpha/beta
        GGv2->>GGv2: MXFP8 grouped GEMM (CUDA-graph safe)
        GGv2-->>PY: output
    end
```
Reviews (6): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
```python
assert False, (
    "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
    " scaling_mode {}".format(scaling_mode)
)
```
assert False makes fallback unreachable
The assert False statements at lines 1028, 1036, and 1045 will always raise AssertionError before the return False on the next line, making those returns dead code. More critically, if Python is run with optimizations enabled (-O flag, which disables asserts), the assert False becomes a no-op and execution falls through — the function would silently skip the validation and continue to later checks or return True, potentially routing data to the V2 kernel under unsupported conditions.
These should be changed to raise an explicit exception or simply return False (if fallback to V1 is the intended behavior) without using assert:
```diff
-assert False, (
-    "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got"
-    " scaling_mode {}".format(scaling_mode)
-)
+return False
```
This same pattern repeats at lines 1036-1039 and 1044-1048.
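The hazard is easy to demonstrate: under `python -O`, assert statements compile to no-ops. A minimal sketch of the explicit-return alternative (names here are illustrative, not the PR's actual helper):

```python
def supports_v2_scaling_mode(scaling_mode):
    """Hypothetical dispatch guard: prefer an explicit return over assert.

    Under `python -O`, `assert False, "..."` is stripped entirely, so a
    guard written with assert would fall through and could route data to
    the V2 kernel under unsupported conditions. An explicit `return False`
    keeps the V1 fallback reachable in all interpreter modes.
    """
    if scaling_mode != "MXFP8_1D_SCALING":
        return False
    return True
```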
```cpp
cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
                stream);
// Note: This may break cudaGraph.
cudaStreamSynchronize(stream);
}
// size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0);
// if (!is_rhs_ragged) {
//   NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m,
```
Commented-out group_sizes sum validation
The validation that sum(group_sizes) matches m (or k for wgrad) has been commented out entirely. While the new *_first_dims/*_last_dims interface changes how dimensions are communicated, removing this runtime sanity check eliminates a useful guard against dimension mismatches that could lead to silent data corruption or out-of-bounds memory access. Consider either adapting this validation to work with the new interface or adding an equivalent check.
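One way to keep an equivalent guard is a host-side check at dispatch time, under the assumption that group sizes are available on the host before launch (the function and its name are hypothetical):

```python
import numpy as np

def validate_group_sizes(group_sizes, ragged_dim):
    # Equivalent of the commented-out NVTE_CHECK: the per-group first
    # dimensions must exactly cover the ragged dimension, otherwise the
    # grouped GEMM could read or write out of bounds.
    total = int(np.asarray(group_sizes, dtype=np.int64).sum())
    if total != ragged_dim:
        raise ValueError(
            f"Unexpected group_sizes: sum={total}, expected {ragged_dim}"
        )

validate_group_sizes([128, 256, 128], 512)  # ok: 128 + 256 + 128 == 512
```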
Force-pushed 833cb3e to 2dd69d4
/te-ci

/te-ci

/te-ci
```python
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]

is_v2_grouped_gemm_supported = get_device_compute_capability(0) >= 100
v2_grouped_gemm_unsupported_reason = "V2 grouped GEMM requires SM100+ (Blackwell or newer)"
```
maybe we should wrap this into a util somewhere and reuse it to guard all calls to V2 grouped GEMM, not just from test_custom_call_compute
probably would also make the def grouped_gemm in gemm.py shorter?
Good idea, I've decided to simplify the test code to make it less coupled to V1/V2. I still have some comments to indicate which test cases should trigger V1/V2, but there is less V1/V2 logic in the tests themselves and it is left as more of an internal implementation detail.
Separately, I've also simplified the grouped_gemm function as I agree that function body was too complex. It is now refactored into several helper functions. It could be cleaned up further, but it's at least better than it was previously. Thanks!
```python
# *32 so that the input shapes work for MXFP8
input_shape = (m * 32, n)
# Use 128 multiplier for V2-eligible MXFP8 shapes (both M and K 128-aligned)
```
Make it clearer that the 128-alignment is a cuBLASLt requirement, while the 32 multiplier comes from MXFP8 scale factors applying to chunks of 32 elements.
This isn't solely due to cuBLASLt. The grouped quantize kernel also has these alignment requirements. I've refactored this test code to be less coupled to the internal V1/V2 logic and instead tried to select a handful of test cases that should cover both V1 and V2, and whether V1 or V2 is selected is more of an implementation detail than visible at the test-level (except for some small notes next to the configs to show both V1 and V2 should be covered).
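The two constraints discussed above can be sketched as a simple eligibility check (the constants and helper name are illustrative assumptions, not the PR's actual code):

```python
MXFP8_BLOCK = 32     # MXFP8 scale factors apply to chunks of 32 elements
V2_ALIGNMENT = 128   # V2 grouped quantize/GEMM kernels require 128-aligned dims

def choose_kernel(m, k, sm_arch):
    # Any MXFP8 shape must be a multiple of the 32-element scaling block;
    # the V2 path additionally needs SM100+ and 128-aligned dimensions.
    assert m % MXFP8_BLOCK == 0 and k % MXFP8_BLOCK == 0, "not a valid MXFP8 shape"
    if sm_arch >= 100 and m % V2_ALIGNMENT == 0 and k % V2_ALIGNMENT == 0:
        return "V2"
    return "V1"
```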
```python
    return False
# V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both
# operands is a multiple of 128. The V2 GEMM setup kernel computes per-group
# scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``.
```
this took me a bit to understand; consider clarifying what K_blocks is, as it is not defined in this file. If after a 2nd read it still feels pretty trivial then feel free to SR
I've reworded this so it's clearer. I believe we could support cases where this dim is not divisible by 128; there is no inherent limitation in the GEMM, afaik. But currently, for simplicity, the grouped quantize and grouped GEMM setup kernels only handle these offsets correctly when this dim is divisible by 128.
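The offset arithmetic behind this restriction can be illustrated as follows (a sketch; the divisor 32 is the MXFP8 block size mentioned above, and the helper name is made up):

```python
def group_scale_offset(data_offset, block=32):
    # The V2 setup kernel derives each group's scale pointer as
    # data_offset / 32 (one scale per 32-element block). This only lands
    # on a valid scale boundary when every group's data offset stays a
    # multiple of the block size, which the current kernels guarantee by
    # requiring the trailing dim to be 128-aligned.
    if data_offset % block != 0:
        raise ValueError("data_offset not on a scale-block boundary")
    return data_offset // block
```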
```python
# [n_groups int64 group_sizes | n_groups+1 int64 offsets]
# = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8.
n_groups = group_sizes_aval.size
fifth_out_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8)
```
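For reference, the packed layout described in those comments can be round-tripped on the host like this (a NumPy sketch; the helper is hypothetical):

```python
import numpy as np

def unpack_group_workspace(buf_u8, n_groups):
    # The uint8 buffer packs [n_groups int64 group_sizes |
    # n_groups+1 int64 offsets] = (2*n_groups + 1) * 8 bytes.
    vals = np.frombuffer(bytes(buf_u8), dtype=np.int64)
    assert vals.size == 2 * n_groups + 1, "workspace size mismatch"
    return vals[:n_groups], vals[n_groups:]
```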
fifth output seems like a bad name for this. Maybe group_sizes_and_offsets?
oh I see that it is updated_amax for V1. Not sure what the best name would be here, given that it serves different purposes in the two versions
Good point, this is a bad name. Instead of this overloaded 5th output, I've instead made both FFIs use 6 outputs and left the workspace empty on V1 for consistency. For V2, if we ever want to support delayed scaling we would need this updated amax output anyways
```python
if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING:
    return False
# Require SM100+ so V2 quantize (fused swizzle) is only used alongside V2 GEMM.
if get_min_device_compute_capability() < 100:
```
in gemm.py you check get_device_compute_capability, but here it is get_min_device_compute_capability. These would be equivalent if all GPUs on the system have the same compute capability (which is true for most of our products, maybe minus the Galaxy ones, I don't remember clearly). But for consistency, please use the same function.
Same for the test file too
Good catch, thanks! I've updated to get the changes in this PR to use min device capability
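A shared helper along the lines the reviewer suggests might look like this (names are hypothetical; taking the minimum over all devices is the conservative choice on mixed-GPU systems):

```python
def min_compute_capability(device_caps):
    # device_caps: compute capability per visible device, e.g. [100, 90].
    return min(device_caps)

def is_v2_grouped_gemm_supported(device_caps, required_sm=100):
    # Gate both the V2 grouped quantize and the V2 grouped GEMM on the
    # same check so a mixed system never takes the V2 path on an older GPU.
    return min_compute_capability(device_caps) >= required_sm
```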
```cuda
// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim).
// Produces n_groups+1 values. Single-threaded sequential scan; n_groups is typically small.
__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets,
```
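A host-side reference for the scan described in that comment (function name assumed for illustration):

```python
def compute_grouped_tensor_offsets(first_dims, last_dim):
    # Exclusive prefix sum scaled by last_dim:
    # offsets[0] = 0, offsets[i] = sum(first_dims[0..i-1]) * last_dim.
    offsets = [0]
    for d in first_dims:
        offsets.append(offsets[-1] + d * last_dim)
    return offsets  # n_groups + 1 values
```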
I have an idea for this in case n_groups ever gets large: have 32 threads compute cumulative sums over blocks, then use warp shuffles to reduce the local sums to a single sum.
Sounds good! Currently the kernel runtime is pretty small relative to our other kernels and our n_groups per device is fairly small with EP, but good idea for future if n_groups per device gets bigger
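The blocked-scan idea can be sketched on the host (a block of 32 stands in for a warp; an actual kernel would replace the inner loop with a warp-level scan and the running total with a shuffle reduction):

```python
def blocked_exclusive_scan(values, block=32):
    # Scan each block independently (what 32 threads would do in
    # parallel), then fold each block's total into a running offset
    # (what a warp shuffle reduction would produce).
    offsets, running = [], 0
    for start in range(0, len(values), block):
        local = 0
        for v in values[start:start + block]:
            offsets.append(running + local)
            local += v
        running += local
    offsets.append(running)
    return offsets
```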
Force-pushed 4234eca to bf6377b
Force-pushed 2bf1e25 to 513108a
/te-ci
Description
TE/JAX integration of the V2 MXFP8 grouped quantization kernel and the V2 MXFP8 grouped GEMM, both of which are CUDA-graph-safe.
Type of change
Changes
Checklist: