Bias Prob Scaling for GroupedLinear and Fused MOE Layers#2864

Open
vthumbe1503 wants to merge 7 commits into NVIDIA:main from vthumbe1503:grouped_bias_add

Conversation

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 changed the title from "bias*prob, dbias+dprob triton kernel" to "Bias Prob Scaling for GroupedLinear and FusedGrouped Layers" on Apr 9, 2026
vthumbe1503 and others added 2 commits April 9, 2026 15:37
@vthumbe1503 vthumbe1503 requested a review from timmoon10 April 9, 2026 22:49
@greptile-apps
Contributor

greptile-apps bot commented Apr 9, 2026

Greptile Summary

This PR adds "bias probability scaling" (scale_bias=True) to GroupedLinear and the fused MoE layers, so FC2 computes X @ W^T + bias * prob rather than X @ W^T + bias. It introduces a fused Triton kernel (_compute_grouped_dbias_dscales) that jointly computes dbias and d_prob in a single pass, wires the feature into both the non-quantized GroupedLinear.fuser_forward/backward and the MXFP8 cuDNN-fused forward/backward paths, and updates the test to validate the new semantics.

  • The MXFP8 fused forward path silently drops the FC2 bias if is_fc2_bias_supported() returns False while scale_bias=True; no runtime guard or fallback exists.
  • The semantics of prob_tensor in grouped_gemm_quant_wrapper_sm100 (FC2 forward) need verification: if it scales the full GEMM output rather than only the bias, the MXFP8 path will produce prob² · GEMM + bias · prob instead of the desired GEMM · prob + bias · prob, yielding incorrect forward activations and gradients on SM100.
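The intended scale_bias semantics summarized above can be written as a small reference sketch. This models the behavior described in the review (per-group GEMM with the bias, and only the bias, scaled by the per-token routing probability); the function name and argument layout are illustrative, not the actual Transformer Engine API.

```python
import torch

def grouped_linear_scale_bias(x, weights, biases, probs, split_sizes):
    """Reference for scale_bias=True: per group g,
    y = x_g @ W_g^T + b_g * prob, with the GEMM term left unscaled
    (the FC2 input already carries one prob factor from ScaledSwiGLU)."""
    outs = []
    start = 0
    for g, n in enumerate(split_sizes):
        x_g = x[start:start + n]                   # tokens routed to expert g
        p_g = probs[start:start + n].unsqueeze(1)  # per-token routing prob
        # Only the bias is scaled by prob; scaling the whole output here
        # would reintroduce the prob^2 issue flagged above.
        outs.append(x_g @ weights[g].T + biases[g] * p_g)
        start += n
    return torch.cat(outs)
```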

Confidence Score: 4/5

Safe to merge for non-MXFP8 paths; MXFP8 fused path has an unverified prob_tensor semantic and a missing guard for older cuDNN.

The non-quantized path (most users) is correct end-to-end: forward adds bias*prob, the Triton kernel computes both dbias and dprob in one fused pass, and gradients are properly routed. The two P1 concerns are confined to the MXFP8 CuTe-DSL fused path (SM100-only, gated by env var): (1) prob_tensor semantics in grouped_gemm_quant_wrapper_sm100 need confirmation to rule out a prob² forward-pass error, and (2) when an older cuDNN doesn't expose bias_tensor, scale_bias=True silently produces no-bias output with no error.
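The prob² concern can be made concrete with a small numeric sketch. The tensor names here are hypothetical and model only the semantics under suspicion; this is not the actual grouped_gemm_quant_wrapper_sm100 call.

```python
import torch

torch.manual_seed(0)
a = torch.randn(2, 3)   # SwiGLU output before prob scaling, 2 tokens
w2 = torch.randn(4, 3)  # FC2 weight
b2 = torch.randn(4)     # FC2 bias
prob = torch.rand(2, 1) # per-token routing probability

x = a * prob                      # FC2 input already carries one prob factor
desired = x @ w2.T + b2 * prob    # GEMM * prob + bias * prob
suspect = (x @ w2.T + b2) * prob  # if prob_tensor rescales the full GEMM
                                  # output: prob^2 * GEMM + bias * prob
# The two differ by (prob^2 - prob) * (a @ w2.T), so the suspect path
# produces wrong forward activations whenever prob is not 0 or 1.
```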

transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py — lines 430-452 where prob_tensor and bias_tensor are set for the MXFP8 FC2 GEMM.

Important Files Changed

  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: FC2 GEMM's prob_tensor is now set to fc2_scales in the MXFP8 fused path; if this scales the full GEMM output (not just bias), it introduces a prob² factor. Additionally, when is_fc2_bias_supported() is False, bias is silently dropped with no fallback.
  • transformer_engine/pytorch/ops/basic/grouped_linear.py: Non-MXFP8 forward/backward for scale_bias=True looks correct: bias is added as bias * prob in forward, and _compute_grouped_dbias_dscales properly computes dbias and d_prob in backward.
  • transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py: MXFP8 backward correctly accumulates the bias contribution into grad_scales via the Triton kernel; the shared-tensor assumption for SwiGLU and FC2 probs remains undocumented but is acknowledged.
  • transformer_engine/pytorch/triton/grouped_dbias_dscales.py: Wrapper validates dbias dtype (float32) but has a docstring inaccuracy: dbias is always returned as float32, not "same dtype as dy".
  • transformer_engine/common/triton/grouped_dbias_dscales.py: Triton kernel logic is correct: per-group dbias accumulates via registers plus one atomic-add; per-token dscales accumulates via atomic-add per column tile. Empty groups (group_rows=0) are handled safely.
  • tests/pytorch/test_fusible_ops.py: Reference implementation updated correctly to bias * prob; scale_bias=bias now passed to FC2; MXFP8+bias skip removed; parametrize decorators use bool iterables correctly.

Sequence Diagram

sequenceDiagram
    participant Input
    participant FC1 as GroupedLinear FC1
    participant SwiGLU as ScaledSwiGLU (prob)
    participant FC2 as GroupedLinear FC2 (scale_bias)
    participant Triton as _compute_grouped_dbias_dscales

    Note over Input,FC2: Forward pass
    Input->>FC1: x → X @ W1^T + b1
    FC1->>SwiGLU: h1
    SwiGLU->>FC2: SwiGLU(h1) * prob
    FC2->>FC2: GEMM → (SwiGLU*prob) @ W2^T
    FC2->>FC2: + bias * prob  [scale_bias=True]
    FC2-->>Input: output

    Note over Input,Triton: Backward pass
    FC2->>Triton: dy, scales=prob, bias, offsets
    Triton->>Triton: dbias[g] += sum_i dy[i]*prob[i]
    Triton->>Triton: dscales[i] += sum_j dy[i,j]*bias[g,j]
    Triton-->>FC2: dbias, dprob_from_bias
    FC2->>SwiGLU: grad_scales = dprob_swiglu + dprob_from_bias
    SwiGLU-->>Input: d_input via FC1 dgrad
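The two backward-pass formulas in the diagram can be checked with a naive reference (a sketch of the semantics, not the fused Triton kernel; names are illustrative):

```python
import torch

def grouped_dbias_dscales_ref(dy, probs, bias_packed, split_sizes):
    """Reference for the fused kernel's outputs: per-group
    dbias[g] = sum_{i in g} dy[i] * prob[i], and the bias contribution
    to dprob, dscales[i] = sum_j dy[i, j] * bias[g, j]."""
    num_groups, hidden = bias_packed.shape
    dbias = torch.zeros(num_groups, hidden, dtype=torch.float32)
    dscales = torch.zeros(dy.shape[0], dtype=torch.float32)
    start = 0
    for g, n in enumerate(split_sizes):
        dy_g = dy[start:start + n].float()
        p_g = probs[start:start + n].float()
        dbias[g] = (dy_g * p_g[:, None]).sum(dim=0)       # dL/dbias per group
        dscales[start:start + n] = dy_g @ bias_packed[g].float()  # dL/dprob per token
        start += n
    return dbias, dscales
```

These are exactly the chain-rule gradients of the forward term bias * prob, which is why the kernel can compute both in a single pass over dy.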


Comment on lines +667 to +671
fc2_grad_extra = (None, None) if fc2_op._scale_bias else (None,)
return (
    grad_input,
    [fc1_grad_params, (), fc2_grad_params],
-   [(None,), (grad_scales,), (None,)],
+   [(None,), (grad_scales,), fc2_grad_extra],
Contributor
P2 FC2 scales gradient silently routed through SwiGLU path — undocumented assumption

When fc2_op._scale_bias is True, fc2_grad_extra returns (None, None) — the gradient for FC2's second extra input (its routing-probability tensor) is not propagated here. Instead, the full grad_scales (SwiGLU prob gradient + FC2 bias contribution accumulated by _compute_grouped_dbias_dscales) is returned via (grad_scales,) in the SwiGLU slot.

This is only correct when the SwiGLU probability tensor and the FC2 probability tensor are the same Python object. The test enforces this by passing probs_test to both, but there is no runtime assertion. If a caller passes distinct tensors, the FC2-prob gradient is silently dropped. Consider adding a guard or documenting this constraint in the class docstring.

Collaborator
This is true. If we were to check for this rigorously, we could check that the data pointers are the same between the SwiGLU and FC2. We would also want to check that FC1 and FC2 are getting the same splits. I don't think this is a blocker though.

Comment on lines +488 to +499
if scale_bias:
    fc2_biases = fc2_op._get_bias_tensors(dtype)
    bias_packed = torch.stack(fc2_biases)
    scales_f32 = scales.detach().to(dtype=torch.float32)
    fc2_dbias_packed_result, grad_scales = _compute_grouped_dbias_dscales(
        fc2_dy,
        scales_f32,
        bias_packed,
        split_sizes,
        offsets=fc1_ctx.base_split_offsets,
        dscales=grad_scales,
    )
Contributor

P2 scales here is the SwiGLU probability, not the FC2-specific probability

scales is restored from swiglu_ctx.saved_tensors (line 298) — it is the probability saved during the FC1/SwiGLU forward, not the FC2-specific routing probability. For the dbias/dscales computation of FC2 the FC2 probability should be used (d_bias2_g = Σ_i dy_i · prob2_i).

This works correctly today because the test always passes the same probs_test tensor to both ops, but it is an undocumented constraint. A comment explaining the shared-tensor assumption would help future maintainers.

Comment on lines +60 to +63
if dbias is None:
    dbias = torch.zeros(num_groups, hidden, dtype=torch.float32, device=dy.device)
if dscales is None:
    dscales = torch.zeros(total_tokens, dtype=torch.float32, device=dy.device)
Contributor
P2 No dtype validation for pre-allocated dscales

When dscales is provided by the caller it is passed directly to tl.atomic_add in the Triton kernel, which requires a float32 pointer. If a caller passes a non-float32 tensor the atomic add silently corrupts the output. Consider adding an assertion:

if dscales is None:
    dscales = torch.zeros(total_tokens, dtype=torch.float32, device=dy.device)
else:
    assert dscales.dtype == torch.float32, (
        f"_compute_grouped_dbias_dscales: dscales must be float32, got {dscales.dtype}"
    )

@vthumbe1503 changed the title from "Bias Prob Scaling for GroupedLinear and FusedGrouped Layers" to "Bias Prob Scaling for GroupedLinear and Fused MOE Layers" on Apr 9, 2026
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

timmoon10
timmoon10 previously approved these changes Apr 9, 2026
Collaborator

@timmoon10 timmoon10 left a comment


LGTM

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
timmoon10
timmoon10 previously approved these changes Apr 10, 2026
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>