Bias Prob Scaling for GroupedLinear and Fused MOE Layers#2864
vthumbe1503 wants to merge 7 commits into NVIDIA:main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary: This PR adds "bias probability scaling" (scale_bias) for GroupedLinear and fused MoE layers: when enabled, FC2 scales its bias by the per-token routing probability, and the backward pass propagates the corresponding gradients.
Confidence Score: 4/5 — safe to merge for non-MXFP8 paths; the MXFP8 fused path has an unverified interaction. The non-quantized path (most users) is correct end-to-end: the forward adds bias * prob when scale_bias is enabled. See transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py, lines 430-452.

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input
    participant FC1 as GroupedLinear FC1
    participant SwiGLU as ScaledSwiGLU (prob)
    participant FC2 as GroupedLinear FC2 (scale_bias)
    participant Triton as _compute_grouped_dbias_dscales
    Note over Input,FC2: Forward pass
    Input->>FC1: x -> X @ W1^T + b1
    FC1->>SwiGLU: h1
    SwiGLU->>FC2: SwiGLU(h1) * prob
    FC2->>FC2: GEMM -> (SwiGLU*prob) @ W2^T
    FC2->>FC2: + bias * prob [scale_bias=True]
    FC2-->>Input: output
    Note over Input,Triton: Backward pass
    FC2->>Triton: dy, scales=prob, bias, offsets
    Triton->>Triton: dbias[g] += sum_i dy[i]*prob[i]
    Triton->>Triton: dscales[i] += sum_j dy[i,j]*bias[g,j]
    Triton-->>FC2: dbias, dprob_from_bias
    FC2->>SwiGLU: grad_scales = dprob_swiglu + dprob_from_bias
    SwiGLU-->>Input: d_input via FC1 dgrad
```
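For reference, the two grouped reductions in the backward pass (dbias[g] = Σ_i dy[i]·prob[i] over tokens i in group g, and dscales[i] = Σ_j dy[i,j]·bias[g(i),j]) can be sketched in plain Python. This is an illustrative re-implementation of the math only, not the Triton kernel; argument names mirror the kernel's (`dy`, `prob`, `bias`, `split_sizes`):

```python
def grouped_dbias_dscales(dy, prob, bias, split_sizes):
    """Reference for the grouped reductions:
      dbias[g][j] = sum over tokens i in group g of dy[i][j] * prob[i]
      dscales[i]  = sum over j of dy[i][j] * bias[g(i)][j]
    dy: per-token gradient rows (tokens x hidden),
    prob: per-token scale, bias: per-group bias rows,
    split_sizes: number of tokens routed to each group."""
    hidden = len(bias[0])
    dbias = [[0.0] * hidden for _ in bias]
    dscales = [0.0] * len(dy)
    i = 0
    for g, n in enumerate(split_sizes):
        for _ in range(n):
            for j in range(hidden):
                dbias[g][j] += dy[i][j] * prob[i]
                dscales[i] += dy[i][j] * bias[g][j]
            i += 1
    return dbias, dscales
```

The real kernel performs these accumulations with atomic adds over a flattened token/hidden grid; the nested loops here only pin down the indexing.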
Reviews (4): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
```diff
+fc2_grad_extra = (None, None) if fc2_op._scale_bias else (None,)
 return (
     grad_input,
     [fc1_grad_params, (), fc2_grad_params],
-    [(None,), (grad_scales,), (None,)],
+    [(None,), (grad_scales,), fc2_grad_extra],
```
FC2 scales gradient silently routed through SwiGLU path — undocumented assumption
When fc2_op._scale_bias is True, fc2_grad_extra returns (None, None) — the gradient for FC2's second extra input (its routing-probability tensor) is not propagated here. Instead, the full grad_scales (SwiGLU prob gradient + FC2 bias contribution accumulated by _compute_grouped_dbias_dscales) is returned via (grad_scales,) in the SwiGLU slot.
This is only correct when the SwiGLU probability tensor and the FC2 probability tensor are the same Python object. The test enforces this by passing probs_test to both, but there is no runtime assertion. If a caller passes distinct tensors, the FC2-prob gradient is silently dropped. Consider adding a guard or documenting this constraint in the class docstring.
There was a problem hiding this comment.
This is true. If we were to check for this rigorously, we could check that the data pointers are the same between the SwiGLU and FC2. We would also want to check that FC1 and FC2 are getting the same splits. I don't think this is a blocker though.
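A minimal runtime guard along the lines discussed above could look like the sketch below. The function name is illustrative, and the check shown is plain object identity; for CUDA tensors one would compare `data_ptr()` as the reply suggests, and also verify that FC1 and FC2 receive the same splits:

```python
def assert_shared_prob(swiglu_prob, fc2_prob):
    """Guard against silently dropping the FC2 prob gradient.

    The fused backward returns a single grad_scales tensor through the
    SwiGLU slot, which is only correct when both ops were given the
    same probability tensor.
    """
    if swiglu_prob is not fc2_prob:
        raise ValueError(
            "SwiGLU and FC2 must share the same probability tensor; "
            "the FC2 prob gradient would otherwise be silently dropped."
        )
```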
```python
if scale_bias:
    fc2_biases = fc2_op._get_bias_tensors(dtype)
    bias_packed = torch.stack(fc2_biases)
    scales_f32 = scales.detach().to(dtype=torch.float32)
    fc2_dbias_packed_result, grad_scales = _compute_grouped_dbias_dscales(
        fc2_dy,
        scales_f32,
        bias_packed,
        split_sizes,
        offsets=fc1_ctx.base_split_offsets,
        dscales=grad_scales,
    )
```
scales here is the SwiGLU probability, not the FC2-specific probability
scales is restored from swiglu_ctx.saved_tensors (line 298) — it is the probability saved during the FC1/SwiGLU forward, not the FC2-specific routing probability. For the dbias/dscales computation of FC2 the FC2 probability should be used (d_bias2_g = Σ_i dy_i · prob2_i).
This works correctly today because the test always passes the same probs_test tensor to both ops, but it is an undocumented constraint. A comment explaining the shared-tensor assumption would help future maintainers.
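One way to record the constraint this comment describes is in the fused op's docstring. The class name below is an illustrative stand-in, not the actual class in the PR:

```python
class ForwardGroupedMLP:  # illustrative stand-in for the fused op
    """Fused GroupedLinear -> ScaledSwiGLU -> GroupedLinear.

    Constraint: when FC2 uses scale_bias=True, the probability tensor
    passed to FC2 must be the same object as the one passed to the
    ScaledSwiGLU op. The backward pass accumulates the FC2 bias
    contribution into the single SwiGLU grad_scales tensor, so a
    distinct FC2 prob tensor would receive no gradient.
    """
```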
```python
if dbias is None:
    dbias = torch.zeros(num_groups, hidden, dtype=torch.float32, device=dy.device)
if dscales is None:
    dscales = torch.zeros(total_tokens, dtype=torch.float32, device=dy.device)
```
No dtype validation for pre-allocated dscales
When dscales is provided by the caller it is passed directly to tl.atomic_add in the Triton kernel, which requires a float32 pointer. If a caller passes a non-float32 tensor the atomic add silently corrupts the output. Consider adding an assertion:
```python
if dscales is None:
    dscales = torch.zeros(total_tokens, dtype=torch.float32, device=dy.device)
else:
    assert dscales.dtype == torch.float32, (
        f"_compute_grouped_dbias_dscales: dscales must be float32, got {dscales.dtype}"
    )
```
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Description

Adds bias probability scaling (scale_bias) for GroupedLinear and fused MoE layers: FC2 scales its bias by the per-token routing probability, and the backward pass computes the bias and probability gradients via _compute_grouped_dbias_dscales.

Fixes # (issue)
Type of change
Changes

Please list the changes introduced in this PR:

- GroupedLinear: optional scale_bias flag that scales the bias by the per-token probability tensor.
- Fused grouped MLP backward: Triton kernel _compute_grouped_dbias_dscales computing grouped dbias and the probability (dscales) gradient.
Checklist: