Draft
… cu_seqlens

- Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing
- Use padded cu_seqlens_kv for K/V reordering (ensures divisibility)
- Add cu_seqlens_kv and cu_seqlens_kv_padded to the AllGather function signature
- Compute per-step Q and KV cu_seqlens correctly from the actual seqlens
- Support non-causal attention (all KV visible)
- Zero-initialize out/dq for THD to avoid garbage in padding regions
- Save per-step cu_seqlens in ctx for backward (avoids recomputation)

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
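The per-step Q-chunk selection above can be sketched in plain Python. This is a minimal illustration, not the TransformerEngine implementation: the helper name `per_step_q_offsets` is hypothetical, and it assumes the usual load-balanced CP layout in which each padded sequence is split into `2 * cp_size` equal chunks and rank `r` holds chunks `r` and `2*cp_size - 1 - r` (one per step).

```python
def per_step_q_offsets(cu_seqlens_q_padded, cp_size, rank, step):
    """Return (start, end) token offsets of rank `rank`'s Q chunk for `step`.

    Hypothetical sketch: assumes each padded sequence divides evenly into
    2*cp_size chunks, with rank r owning chunks r (step 0) and
    2*cp_size - 1 - r (step 1).
    """
    num_chunks = 2 * cp_size
    chunk_id = rank if step == 0 else num_chunks - 1 - rank
    offsets = []
    for i in range(len(cu_seqlens_q_padded) - 1):
        start = cu_seqlens_q_padded[i]
        seqlen = cu_seqlens_q_padded[i + 1] - start
        chunk_len = seqlen // num_chunks  # padding guarantees divisibility
        offsets.append((start + chunk_id * chunk_len,
                        start + (chunk_id + 1) * chunk_len))
    return offsets
```

With `cu_seqlens_q_padded = [0, 8, 16]` and `cp_size = 2`, rank 0 selects offsets `(0, 2)` and `(8, 10)` at step 0 and `(6, 8)` and `(14, 16)` at step 1, which is why padded (divisible) cu_seqlens are needed for the selection rather than tensor slicing over the actual lengths.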
Remove skip gates that blocked the THD format with the all_gather CP comm type.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…seqlens_q_padded

The interleaved valid-mask computation assumed cu_seqlens_q_padded starts at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at a non-zero offset, causing a size mismatch. Use absolute positions from cu_seqlens_q_padded to build the valid mask instead.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
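The fix described above can be illustrated with a small sketch. The helper name `thd_valid_mask` is hypothetical; the point is only that positions come from `cu_seqlens_q_padded` itself (relative to its first entry) rather than from an assumed zero origin, so a padded array starting at a non-zero offset still yields a mask of the right size.

```python
def thd_valid_mask(cu_seqlens, cu_seqlens_padded):
    """Mark which positions in the padded THD layout hold real tokens.

    Hypothetical sketch: cu_seqlens gives actual sequence boundaries,
    cu_seqlens_padded gives padded boundaries and may start at a
    non-zero offset, so absolute positions are taken from it directly.
    """
    base = cu_seqlens_padded[0]
    total = cu_seqlens_padded[-1] - base
    mask = [False] * total
    for i in range(len(cu_seqlens) - 1):
        seqlen = cu_seqlens[i + 1] - cu_seqlens[i]  # actual (unpadded) length
        start = cu_seqlens_padded[i] - base          # absolute padded start
        for t in range(seqlen):
            mask[start + t] = True
    return mask
```

For example, with `cu_seqlens = [0, 3, 5]` and `cu_seqlens_padded = [4, 8, 12]` (offset 4, two sequences padded to 4 tokens each), the mask has length 8 with positions 3, 6, and 7 left invalid.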
sudhakarsingh27 commented on Apr 7, 2026:
```python
if qkv_format == "thd":
    # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d]
    chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device)
    k_ag = reorder_seq_chunks_after_a2a_before_attn_thd(
```
`reorder_seq_chunks_after_a2a_before_attn_thd` and the other related method are no longer "a2a"-specific; rename them to something like `dualchunk_to_contiguous_order_thd` and `contiguous_to_dualchunk_order_thd`.
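The reordering being renamed can be sketched at the chunk-index level. This is a hypothetical illustration (the function name follows the suggested `dualchunk_to_contiguous` naming, not an existing API): after the all-gather, each rank's slot contributes its two load-balanced chunks `(r, 2*cp_size - 1 - r)` back to back, and the permutation below maps each contiguous chunk id to its index in that gathered layout.

```python
def dualchunk_to_contiguous_chunk_ids(cp_size):
    """For each contiguous chunk id c in [0, 2*cp_size), return its index
    in the gathered dual-chunk layout.

    Hypothetical sketch: gathered index 2*r holds chunk r and index
    2*r + 1 holds chunk 2*cp_size - 1 - r, so inverting that placement
    gives the reorder permutation.
    """
    num_chunks = 2 * cp_size
    ids = []
    for c in range(num_chunks):
        if c < cp_size:
            ids.append(2 * c)                        # first chunk of rank c
        else:
            ids.append(2 * (num_chunks - 1 - c) + 1)  # second chunk of its rank
    return ids
```

For `cp_size = 2` the gathered order is chunks `[0, 3, 1, 2]`, so the permutation `[0, 2, 3, 1]` restores contiguous per-sequence order, which is exactly what the K/V reorder before attention needs regardless of whether the gather came from an a2a or an all-gather.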
```diff
 if use_fused_attention and causal and "bottom_right" not in attn_mask_type:
-    attn_mask_type = attn_mask_type + "_bottom_right"
+    if qkv_format != "thd":
+        attn_mask_type = attn_mask_type + "_bottom_right"
```
Why are we creating `*_bottom_right` masks only for non-THD cases? Shouldn't THD cases also have them?
Edit: for THD there are per-step masks, but those masks are the same for each step, i.e. "padding_causal_bottom_right" for both steps if causal, otherwise "padding" for both steps, so maybe there's no need to distinguish these two paths separately.
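The observation in the edit above reduces to a very small rule. The helper below is a hypothetical sketch of that rule (not an existing TransformerEngine function): for THD, every all-gather step gets the same mask type, so no per-step distinction is needed.

```python
def thd_per_step_mask_types(causal, num_steps=2):
    """Per-step attention mask types for THD with all-gather CP.

    Hypothetical sketch of the rule stated in the review comment:
    the mask is identical for every step -- "padding_causal_bottom_right"
    when causal, plain "padding" otherwise.
    """
    mask = "padding_causal_bottom_right" if causal else "padding"
    return [mask] * num_steps
```

Since the list entries are always identical, collapsing the THD and non-THD paths (or keeping the THD path but documenting that its per-step masks coincide) would both be consistent with this behavior.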
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
(resolved)
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: