From db44fc2b939e0839875ddf914b7db699fa4f80e5 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 7 Apr 2026 11:32:18 -0700 Subject: [PATCH 1/6] [PyTorch][CP] Fix THD AllGather CP: offset-based approach with proper cu_seqlens - Use per-step cu_seqlens_q_padded to select Q chunks instead of tensor slicing - Use padded cu_seqlens_kv for K/V reordering (ensures divisibility) - Add cu_seqlens_kv and cu_seqlens_kv_padded to AllGather function signature - Compute per-step Q and KV cu_seqlens correctly from actual seqlens - Support non-causal attention (all KV visible) - Zero-initialize out/dq for THD to avoid garbage in padding regions - Save per-step cu_seqlens in ctx for backward (avoid recomputation) Signed-off-by: Sudhakar Singh --- .../dot_product_attention/context_parallel.py | 457 +++++++++++++----- 1 file changed, 338 insertions(+), 119 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..6bc0e8c050 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2805,9 +2805,11 @@ def forward( k, v, cu_seqlens_q, + cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_padded, + cu_seqlens_kv_padded, dropout_p, softmax_scale, qkv_format, @@ -2834,9 +2836,16 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - assert not padding, f"{attn_mask_type} mask type is not supported!" + if qkv_format == "thd": + # THD always uses padding mask types; per-step masks set internally + assert padding, ( + f"THD format requires padding mask type, got {attn_mask_type}!" + ) + else: + assert not padding, f"{attn_mask_type} mask type is not supported!" 
if use_fused_attention and causal and "bottom_right" not in attn_mask_type: - attn_mask_type = attn_mask_type + "_bottom_right" + if qkv_format != "thd": + attn_mask_type = attn_mask_type + "_bottom_right" assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( @@ -2874,41 +2883,65 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - seq_dim = qkv_format.index("s") - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" + if qkv_format == "thd": + # Save original global cu_seqlens before division + cu_seqlens_q_original = cu_seqlens_q.clone() + else: + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" 
+ # Divide by 2*cp_size to get per-chunk values max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) - if use_fused_attention or qkv_format == "thd": + if use_fused_attention and qkv_format != "thd": cu_seqlens_q = cu_seqlens_q // (2 * cp_size) if cu_seqlens_q_padded is not None and qkv_format == "thd": cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) else: cu_seqlens_q_padded = None - # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] - k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + if qkv_format != "thd": + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] + q = q.view( + *q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :] + ) + # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] - # [s, b, h, d] -> [cp, s, b, h, d] + # AllGather K/V across CP ranks k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] - k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) - v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) - k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) - v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - k_ag = k_ag.view(-1, *k.shape[1:]) - v_ag = v_ag.view(-1, *v.shape[1:]) + if qkv_format == "thd": + # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d] + # Use padded cu_seqlens since reorder computes slice boundaries via integer + # division by 2*cp_size, which requires divisible values. 
+ chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn( + cp_size, k.device + ) + k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + k_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size + ) + v_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + v_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size + ) + else: + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn( + cp_size, k.device + ) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + cp_stream.wait_stream(torch.cuda.current_stream()) # create two streams to resolve wave quantization issue of Flash Attn in each step @@ -2921,39 +2954,124 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - out = torch.empty_like(q) + out = torch.zeros_like(q) if qkv_format == "thd" else torch.empty_like(q) max_logit_per_step = [None, None] max_logit = None + # Pre-compute THD-specific per-step cu_seqlens + if qkv_format == "thd": + # Rank-level padded offsets (2 chunks per sequence on this rank) + cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 + + # Per-step Q cu_seqlens (non-padded): different per step since different + # chunks may have different valid token counts for non-divisible seqlens. 
+ thd_cu_seqlens_q_per_step = [ + get_cu_seqlens_on_cp_rank( + cu_seqlens_q_original, cu_seqlens_q_padded_rank, + cp_size, rank, True, False, + ), + get_cu_seqlens_on_cp_rank( + cu_seqlens_q_original, cu_seqlens_q_padded_rank, + cp_size, rank, False, True, + ), + ] + + # Per-step Q cu_seqlens_padded: offset-based approach — pass full Q tensor + # and vary cu_seqlens_q_padded to point kernel at the correct chunk. + # cuDNN uses back-padding (valid tokens at beginning of padded allocation). + padded_chunk_sizes = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] + actual_seqlens = cu_seqlens_q_original[1:] - cu_seqlens_q_original[:-1] + # Step 0: kernel reads from start of each seq's 2-chunk allocation (first chunk) + # Step 1: kernel reads from midpoint of each seq's allocation (second chunk) + thd_cu_seqlens_q_padded_per_step = [cu_seqlens_q_padded_rank, None] + thd_cu_seqlens_q_padded_per_step[1] = cu_seqlens_q_padded_rank.clone() + thd_cu_seqlens_q_padded_per_step[1][:-1] += padded_chunk_sizes + + # Per-step KV cu_seqlens (non-padded): how many actual KV tokens are + # visible for each sequence. + thd_cu_seqlens_kv_per_step = [None, None] + for step_idx in range(2): + if causal: + # Causal: visible KV covers chunks 0..chunk_id + chunk_id = local_seq_chunk_ids[step_idx] + visible_padded = padded_chunk_sizes * (chunk_id + 1) + visible_actual = torch.minimum(actual_seqlens, visible_padded) + cs = torch.zeros_like(cu_seqlens_q_original) + cs[1:] = visible_actual.cumsum(0) + thd_cu_seqlens_kv_per_step[step_idx] = cs + else: + # Non-causal: all KV tokens visible + thd_cu_seqlens_kv_per_step[step_idx] = cu_seqlens_q_original.clone() + + if causal: + # Q is always the last chunk in the visible KV range, + # so bottom_right alignment is always correct. 
+ thd_attn_mask_type_per_step = [ + "padding_causal_bottom_right", + "padding_causal_bottom_right", + ] + else: + thd_attn_mask_type_per_step = ["padding", "padding"] + for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() - kv_seq_range_per_step[i], window_size_per_step[i] = ( - get_kv_seq_info_after_all_gather( - local_seq_chunk_ids[i], - cp_size, - max_seqlen_q, - max_seqlen_kv, - window_size, - causal, + if qkv_format in ["bshd", "sbhd"]: + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_ = q.select(seq_dim, i).contiguous() + kv_seq_range_per_step[i], window_size_per_step[i] = ( + get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + window_size, + causal, + ) ) - ) - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i][0], - kv_seq_range_per_step[i][1], - ) - max_seqlen_kv_ = seq_end_idx - seq_start_idx - if use_fused_attention or qkv_format == "thd": - cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( - k.shape[1], max_seqlen_kv_, k.device + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], ) - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + max_seqlen_kv_ = seq_end_idx - seq_start_idx + if use_fused_attention: + cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + elif qkv_format == "thd": + # THD: pass full Q with 
per-step cu_seqlens_q_padded to select chunk + q_ = q + k_ = k_ag + v_ = v_ag + chunk_id = local_seq_chunk_ids[i] + if causal: + max_seqlen_kv_ = max_seqlen_kv * (chunk_id + 1) + else: + max_seqlen_kv_ = max_seqlen_kv * (2 * cp_size) + cu_seqlens_kv_per_step[i] = thd_cu_seqlens_kv_per_step[i] + # Window size + if window_size is None: + window_size_per_step[i] = ( + (-1, 0) if causal else (-1, -1) + ) + else: + window_size_per_step[i] = window_size if use_fused_attention: + # Set per-step parameters for THD vs bshd/sbhd + if qkv_format == "thd": + attn_mask_type_ = thd_attn_mask_type_per_step[i] + cu_seqlens_q_ = thd_cu_seqlens_q_per_step[i] + cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + else: + cu_seqlens_q_ = cu_seqlens_q + attn_mask_type_ = attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] ( out_per_step[i], [softmax_lse_per_step[i], rng_states[i]], @@ -2962,7 +3080,7 @@ def forward( is_training, max_seqlen_q, max_seqlen_kv_, - cu_seqlens_q, + cu_seqlens_q_, cu_seqlens_kv_per_step[i], q_, k_, @@ -2972,11 +3090,11 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, + attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, attn_bias=attn_bias, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), @@ -2988,7 +3106,7 @@ def forward( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, + cu_seqlens_q=thd_cu_seqlens_q_per_step[i] if qkv_format == "thd" else cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv_, @@ -3025,6 +3143,18 @@ def forward( out[:, i - 
1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) + elif qkv_format == "thd": + # Copy valid token ranges from this step's output. + # Each step writes at different positions (no overlap, no correction needed). + step_padded = thd_cu_seqlens_q_padded_per_step[i - 1] + step_valid = thd_cu_seqlens_q_per_step[i - 1] + batch_size = step_valid.shape[0] - 1 + for b in range(batch_size): + s = step_padded[b].item() + sz = (step_valid[b + 1] - step_valid[b]).item() + if sz > 0: + out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz]) + if return_max_logit: max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) @@ -3034,7 +3164,9 @@ def forward( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - if use_fused_attention: + if qkv_format == "thd": + pass # out is already [t_rank, h, d], no reshape needed + elif use_fused_attention: if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) elif qkv_format == "sbhd": @@ -3068,6 +3200,12 @@ def forward( ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + if qkv_format == "thd": + ctx.thd_attn_mask_type_per_step = thd_attn_mask_type_per_step + ctx.max_seqlen_kv = max_seqlen_kv + ctx.cu_seqlens_kv_padded = cu_seqlens_kv_padded + ctx.thd_cu_seqlens_q_per_step = thd_cu_seqlens_q_per_step + ctx.thd_cu_seqlens_q_padded_per_step = thd_cu_seqlens_q_padded_per_step nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: return out, max_logit @@ -3089,11 +3227,12 @@ def backward(ctx, dout, *_args): kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step - seq_dim = ctx.qkv_format.index("s") + if ctx.qkv_format != "thd": + seq_dim = ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format dout = dout.view(q.shape) - dq = torch.empty_like(q) + dq = torch.zeros_like(q) 
if ctx.qkv_format == "thd" else torch.empty_like(q) dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) dv = torch.zeros_like(dk) dq_per_step = [None, None] @@ -3105,19 +3244,40 @@ def backward(ctx, dout, *_args): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, h, d] -> [cp, s, b, h, d] + # AllGather K/V across CP ranks k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] - k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) - v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) - k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) - v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - k_ag = k_ag.view(-1, *k.shape[1:]) - v_ag = v_ag.view(-1, *v.shape[1:]) + if ctx.qkv_format == "thd": + cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded + thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step + cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 + + # [cp*t, h, d] -> reorder to contiguous per-sequence order + # Use padded cu_seqlens (divisible by 2*cp_size) for correct reorder + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn( + cp_size, k.device + ) + k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + k_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size + ) + v_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + v_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size + ) + + thd_cu_seqlens_q_padded_per_step = ctx.thd_cu_seqlens_q_padded_per_step + else: + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn( + 
cp_size, k.device + ) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] @@ -3156,25 +3316,50 @@ def backward(ctx, dout, *_args): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i][0], - kv_seq_range_per_step[i][1], - ) - max_seqlen_kv = seq_end_idx - seq_start_idx - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] - out_ = out_per_step[i] - dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) + if ctx.qkv_format == "thd": + # THD: pass full Q/dout with per-step cu_seqlens_q_padded + q_ = q + k_ = k_ag + v_ = v_ag + chunk_id = local_seq_chunk_ids[i] + causal = "causal" in ctx.attn_mask_type + if causal: + max_seqlen_kv = ctx.max_seqlen_kv * (chunk_id + 1) + else: + max_seqlen_kv = ctx.max_seqlen_kv * (2 * cp_size) + out_ = out_per_step[i] + dout_ = dout + else: + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_ = q.select(seq_dim, i).contiguous() + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv = seq_end_idx - seq_start_idx + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + out_ = out_per_step[i] 
+ dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: + # Set per-step parameters for THD + if ctx.qkv_format == "thd": + attn_mask_type_ = ctx.thd_attn_mask_type_per_step[i] + cu_seqlens_q_ = thd_cu_seqlens_q_per_step[i] + cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + else: + cu_seqlens_q_ = cu_seqlens_q + attn_mask_type_ = ctx.attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, + cu_seqlens_q_, cu_seqlens_kv_per_step[i], q_, k_, @@ -3185,12 +3370,12 @@ def backward(ctx, dout, *_args): TE_DType[dout.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, + attn_mask_type=attn_mask_type_, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, @@ -3204,7 +3389,7 @@ def backward(ctx, dout, *_args): False, ctx.use_flash_attn_3, ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q, + cu_seqlens_q=thd_cu_seqlens_q_per_step[i] if ctx.qkv_format == "thd" else cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=max_seqlen_kv, @@ -3236,44 +3421,78 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if ctx.qkv_format == "bshd": - dq[:, i - 1].copy_(dq_per_step[i - 1]) - elif ctx.qkv_format == "sbhd": - dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, 
h, d] - dk_per_step[i - 1], dv_per_step[i - 1] = [ - x.movedim(seq_dim, 0).contiguous() - for x in [dk_per_step[i - 1], dv_per_step[i - 1]] - ] - # wait until dkv update of last step is done - if i > 1: - flash_attn_streams[i - 1].wait_event(dkv_update_done) - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i - 1][0], - kv_seq_range_per_step[i - 1][1], - ) - dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) - dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) - if i < len(local_seq_chunk_ids): - flash_attn_streams[i - 1].record_event(dkv_update_done) + if ctx.qkv_format == "thd": + # dQ: copy valid token ranges from this step's dQ + step_padded = thd_cu_seqlens_q_padded_per_step[i - 1] + step_valid = thd_cu_seqlens_q_per_step[i - 1] + batch_size = step_valid.shape[0] - 1 + for b in range(batch_size): + s = step_padded[b].item() + sz = (step_valid[b + 1] - step_valid[b]).item() + if sz > 0: + dq[s : s + sz].copy_(dq_per_step[i - 1][s : s + sz]) + # dK/dV: add full tensor (kernel zeros non-valid positions) + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + dk.add_(dk_per_step[i - 1]) + dv.add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) + else: + if ctx.qkv_format == "bshd": + dq[:, i - 1].copy_(dq_per_step[i - 1]) + elif ctx.qkv_format == "sbhd": + dq[i - 1].copy_(dq_per_step[i - 1]) + # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] + dk_per_step[i - 1], dv_per_step[i - 1] = [ + x.movedim(seq_dim, 0).contiguous() + for x in [dk_per_step[i - 1], dv_per_step[i - 1]] + ] + # wait until dkv update of last step is done + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i - 1][0], + kv_seq_range_per_step[i - 1][1], + ) + dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) + dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + 
flash_attn_streams[i - 1].record_event(dkv_update_done) torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] - dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) - dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) - dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) - dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - dk = dk.view(-1, *dk.shape[-3:]) - dv = dv.view(-1, *dv.shape[-3:]) - dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) - dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - - dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) - dk = dk.movedim(0, seq_dim).contiguous() - dv = dv.movedim(0, seq_dim).contiguous() + if ctx.qkv_format == "thd": + # Reverse-reorder dK/dV from contiguous order back to dual-chunk order, + # then reduce-scatter across CP ranks. + # Use padded cu_seqlens for correct slice boundaries. 
+ dk = reorder_seq_chunks_before_a2a_after_attn_thd( + dk, cu_seqlens_kv_padded, cp_size + ) + dv = reorder_seq_chunks_before_a2a_after_attn_thd( + dv, cu_seqlens_kv_padded, cp_size + ) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + # dQ is already [t_rank, h, d], no reshape needed + else: + # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] + dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) + dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn( + cp_size, dk.device + ) + dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) + dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + + dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) + dk = dk.movedim(0, seq_dim).contiguous() + dv = dv.movedim(0, seq_dim).contiguous() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( @@ -3298,6 +3517,8 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, ) @@ -4129,8 +4350,6 @@ def attn_forward_func_with_cp( ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": - args.pop(5) - args.pop(8) args += [window_size, cp_group, cp_stream, use_flash_attn_3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": From 1a5ca4c4c2ab7d0e3967d4d75375ec83af95ce5e Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 7 Apr 2026 11:32:24 -0700 Subject: [PATCH 2/6] [PyTorch][CP] Enable THD+all_gather tests in test_attention_with_cp Remove skip gates that blocked THD format with all_gather CP comm type. 
Signed-off-by: Sudhakar Singh --- tests/pytorch/attention/test_attention_with_cp.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 5aaf67061b..b5fc364df3 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -99,8 +99,6 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "a2a+p2p": pytest.skip( "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" @@ -267,8 +265,6 @@ def test_cp_with_fused_attention( if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "a2a+p2p": pytest.skip( "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" From b4db9eb7e9d3ab19379ffda500292123f4d5423b Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 7 Apr 2026 11:32:31 -0700 Subject: [PATCH 3/6] [PyTorch][Fused Attn] Fix max_logit masking for non-zero-starting cu_seqlens_q_padded The interleaved valid mask computation assumed cu_seqlens_q_padded starts at 0. With the CP offset-based approach, cu_seqlens_q_padded can start at a non-zero offset, causing a size mismatch. Use absolute positions from cu_seqlens_q_padded to build the valid mask instead. 
Signed-off-by: Sudhakar Singh --- .../pytorch/cpp_extensions/fused_attn.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 06bfb6ef3c..52e42bdc0a 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -317,7 +317,6 @@ def fused_attn_fwd( raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel - output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, @@ -377,23 +376,18 @@ def fused_attn_fwd( if cu_seqlens_q_padded is not None: # For THD + pad_between_seqs=True + non-sm120 + cuDNN>9.6, Max tensor is [tq, h, 1] # and padding positions could be uninitialized. Exclude those padded positions when - # computing max_logit. + # computing max_logit. Use absolute positions from cu_seqlens_q_padded to handle + # cases where cu_seqlens_q_padded may not start at 0 (e.g. CP offset-based approach). actual_seqlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to( device=max_tensor.device ) - padded_seqlens = (cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]).to( - device=max_tensor.device - ) - pad_lens = (padded_seqlens - actual_seqlens).to(device=max_tensor.device) - b = pad_lens.shape[0] - - # Stack [actual, pad] per batch into counts: e.g. 
[3,1, 3,1, 2,2, 7,1] - counts = torch.stack([actual_seqlens, pad_lens], dim=1).flatten() - # Tile [T, F] per sequence: [T,F, T,F, T,F, T,F] - values = torch.tensor([True, False], device=max_tensor.device).repeat(b) - # Expand: T×3, F×1, T×3, F×1, T×2, F×2, T×7, F×1 → TTTF|TTTF|TTFF|TTTTTTTF - valid = torch.repeat_interleave(values, counts) - # Finally, replace invalid (F) positions with -inf + tq = max_tensor.shape[0] + valid = torch.zeros(tq, dtype=torch.bool, device=max_tensor.device) + b = actual_seqlens.shape[0] + for b_idx in range(b): + start = cu_seqlens_q_padded[b_idx].item() + n_valid = actual_seqlens[b_idx].item() + valid[start : start + n_valid] = True max_tensor = max_tensor.masked_fill(~valid.view(-1, 1, 1), float("-inf")) # Max -> max_logit [h] From 7491ab6c3c11e9604fa42d2502889f374bb38992 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:38:30 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../dot_product_attention/context_parallel.py | 68 +++++++++---------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 6bc0e8c050..7d33014f35 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2838,9 +2838,7 @@ def forward( padding = "padding" in attn_mask_type if qkv_format == "thd": # THD always uses padding mask types; per-step masks set internally - assert padding, ( - f"THD format requires padding mask type, got {attn_mask_type}!" - ) + assert padding, f"THD format requires padding mask type, got {attn_mask_type}!" else: assert not padding, f"{attn_mask_type} mask type is not supported!" 
if use_fused_attention and causal and "bottom_right" not in attn_mask_type: @@ -2906,9 +2904,7 @@ def forward( if qkv_format != "thd": # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view( - *q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :] - ) + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] @@ -2920,9 +2916,7 @@ def forward( # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d] # Use padded cu_seqlens since reorder computes slice boundaries via integer # division by 2*cp_size, which requires divisible values. - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn( - cp_size, k.device - ) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( k_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size ) @@ -2933,9 +2927,7 @@ def forward( # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn( - cp_size, k.device - ) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] @@ -2967,12 +2959,20 @@ def forward( # chunks may have different valid token counts for non-divisible seqlens. 
thd_cu_seqlens_q_per_step = [ get_cu_seqlens_on_cp_rank( - cu_seqlens_q_original, cu_seqlens_q_padded_rank, - cp_size, rank, True, False, + cu_seqlens_q_original, + cu_seqlens_q_padded_rank, + cp_size, + rank, + True, + False, ), get_cu_seqlens_on_cp_rank( - cu_seqlens_q_original, cu_seqlens_q_padded_rank, - cp_size, rank, False, True, + cu_seqlens_q_original, + cu_seqlens_q_padded_rank, + cp_size, + rank, + False, + True, ), ] @@ -3055,9 +3055,7 @@ def forward( cu_seqlens_kv_per_step[i] = thd_cu_seqlens_kv_per_step[i] # Window size if window_size is None: - window_size_per_step[i] = ( - (-1, 0) if causal else (-1, -1) - ) + window_size_per_step[i] = (-1, 0) if causal else (-1, -1) else: window_size_per_step[i] = window_size if use_fused_attention: @@ -3106,7 +3104,11 @@ def forward( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=thd_cu_seqlens_q_per_step[i] if qkv_format == "thd" else cu_seqlens_q, + cu_seqlens_q=( + thd_cu_seqlens_q_per_step[i] + if qkv_format == "thd" + else cu_seqlens_q + ), cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv_, @@ -3255,9 +3257,7 @@ def backward(ctx, dout, *_args): # [cp*t, h, d] -> reorder to contiguous per-sequence order # Use padded cu_seqlens (divisible by 2*cp_size) for correct reorder - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn( - cp_size, k.device - ) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( k_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size ) @@ -3270,9 +3270,7 @@ def backward(ctx, dout, *_args): # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn( - cp_size, k.device - ) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = 
torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] @@ -3389,7 +3387,11 @@ def backward(ctx, dout, *_args): False, ctx.use_flash_attn_3, ctx.qkv_format, - cu_seqlens_q=thd_cu_seqlens_q_per_step[i] if ctx.qkv_format == "thd" else cu_seqlens_q, + cu_seqlens_q=( + thd_cu_seqlens_q_per_step[i] + if ctx.qkv_format == "thd" + else cu_seqlens_q + ), cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=max_seqlen_kv, @@ -3466,12 +3468,8 @@ def backward(ctx, dout, *_args): # Reverse-reorder dK/dV from contiguous order back to dual-chunk order, # then reduce-scatter across CP ranks. # Use padded cu_seqlens for correct slice boundaries. - dk = reorder_seq_chunks_before_a2a_after_attn_thd( - dk, cu_seqlens_kv_padded, cp_size - ) - dv = reorder_seq_chunks_before_a2a_after_attn_thd( - dv, cu_seqlens_kv_padded, cp_size - ) + dk = reorder_seq_chunks_before_a2a_after_attn_thd(dk, cu_seqlens_kv_padded, cp_size) + dv = reorder_seq_chunks_before_a2a_after_attn_thd(dv, cu_seqlens_kv_padded, cp_size) dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) # dQ is already [t_rank, h, d], no reshape needed @@ -3479,9 +3477,7 @@ def backward(ctx, dout, *_args): # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn( - cp_size, dk.device - ) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] From b957725fef67134f0ac299b63494107913f29669 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 10 Apr 2026 13:55:00 -0700 Subject: [PATCH 5/6] 
some cleanup of ag+thd impl and gate e2e test for flash+ag+thd Signed-off-by: Sudhakar Singh --- .../attention/test_attention_with_cp.py | 5 ++ .../dot_product_attention/context_parallel.py | 46 +++++++------------ 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index b5fc364df3..bbfecd6969 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -99,6 +99,11 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") if qkv_format == "thd": + if cp_comm_type == "all_gather": + pytest.skip( + "FlashAttention does not support THD padding; use FusedAttention for" + " THD+all_gather CP." + ) if cp_comm_type == "a2a+p2p": pytest.skip( "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 6bc0e8c050..6bd40fb7e3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2844,8 +2844,7 @@ def forward( else: assert not padding, f"{attn_mask_type} mask type is not supported!" if use_fused_attention and causal and "bottom_right" not in attn_mask_type: - if qkv_format != "thd": - attn_mask_type = attn_mask_type + "_bottom_right" + attn_mask_type = attn_mask_type + "_bottom_right" assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!"
assert ( @@ -2888,6 +2887,7 @@ def forward( if qkv_format == "thd": # Save original global cu_seqlens before division cu_seqlens_q_original = cu_seqlens_q.clone() + cu_seqlens_kv_original = cu_seqlens_kv.clone() else: seq_dim = qkv_format.index("s") assert ( @@ -2899,8 +2899,9 @@ def forward( max_seqlen_kv = max_seqlen_kv // (2 * cp_size) if use_fused_attention and qkv_format != "thd": cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - if cu_seqlens_q_padded is not None and qkv_format == "thd": + if qkv_format == "thd": cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size) else: cu_seqlens_q_padded = None @@ -2924,10 +2925,10 @@ def forward( cp_size, k.device ) k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( - k_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size + k_ag, cu_seqlens_kv_padded * 2 * cp_size, chunk_ids_for_kv_ag, cp_size ) v_ag = reorder_seq_chunks_after_a2a_before_attn_thd( - v_ag, cu_seqlens_kv_padded, chunk_ids_for_kv_ag, cp_size + v_ag, cu_seqlens_kv_padded * 2 * cp_size, chunk_ids_for_kv_ag, cp_size ) else: # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] @@ -2979,39 +2980,27 @@ def forward( # Per-step Q cu_seqlens_padded: offset-based approach — pass full Q tensor # and vary cu_seqlens_q_padded to point kernel at the correct chunk. # cuDNN uses back-padding (valid tokens at beginning of padded allocation). 
- padded_chunk_sizes = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] - actual_seqlens = cu_seqlens_q_original[1:] - cu_seqlens_q_original[:-1] + padded_chunk_sizes_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] + actual_seqlens_kv = cu_seqlens_kv_original[1:] - cu_seqlens_kv_original[:-1] # Step 0: kernel reads from start of each seq's 2-chunk allocation (first chunk) # Step 1: kernel reads from midpoint of each seq's allocation (second chunk) thd_cu_seqlens_q_padded_per_step = [cu_seqlens_q_padded_rank, None] thd_cu_seqlens_q_padded_per_step[1] = cu_seqlens_q_padded_rank.clone() - thd_cu_seqlens_q_padded_per_step[1][:-1] += padded_chunk_sizes + thd_cu_seqlens_q_padded_per_step[1][:-1] += padded_chunk_sizes_q # Per-step KV cu_seqlens (non-padded): how many actual KV tokens are # visible for each sequence. - thd_cu_seqlens_kv_per_step = [None, None] + padded_chunk_sizes_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1] + thd_cu_seqlens_kv_per_step = [cu_seqlens_q_original.clone(), cu_seqlens_q_original.clone()] for step_idx in range(2): if causal: # Causal: visible KV covers chunks 0..chunk_id chunk_id = local_seq_chunk_ids[step_idx] - visible_padded = padded_chunk_sizes * (chunk_id + 1) - visible_actual = torch.minimum(actual_seqlens, visible_padded) - cs = torch.zeros_like(cu_seqlens_q_original) + visible_padded = padded_chunk_sizes_kv * (chunk_id + 1) + visible_actual = torch.minimum(actual_seqlens_kv, visible_padded) + cs = torch.zeros_like(cu_seqlens_kv_original) cs[1:] = visible_actual.cumsum(0) thd_cu_seqlens_kv_per_step[step_idx] = cs - else: - # Non-causal: all KV tokens visible - thd_cu_seqlens_kv_per_step[step_idx] = cu_seqlens_q_original.clone() - - if causal: - # Q is always the last chunk in the visible KV range, - # so bottom_right alignment is always correct. 
- thd_attn_mask_type_per_step = [ - "padding_causal_bottom_right", - "padding_causal_bottom_right", - ] - else: - thd_attn_mask_type_per_step = ["padding", "padding"] for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): @@ -3063,13 +3052,11 @@ def forward( if use_fused_attention: # Set per-step parameters for THD vs bshd/sbhd if qkv_format == "thd": - attn_mask_type_ = thd_attn_mask_type_per_step[i] cu_seqlens_q_ = thd_cu_seqlens_q_per_step[i] cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] cu_seqlens_kv_padded_ = cu_seqlens_kv_padded else: cu_seqlens_q_ = cu_seqlens_q - attn_mask_type_ = attn_mask_type cu_seqlens_q_padded_ = cu_seqlens_q_padded cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] ( @@ -3090,7 +3077,7 @@ def forward( attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type_, + attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, cu_seqlens_q_padded=cu_seqlens_q_padded_, @@ -3201,7 +3188,6 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 if qkv_format == "thd": - ctx.thd_attn_mask_type_per_step = thd_attn_mask_type_per_step ctx.max_seqlen_kv = max_seqlen_kv ctx.cu_seqlens_kv_padded = cu_seqlens_kv_padded ctx.thd_cu_seqlens_q_per_step = thd_cu_seqlens_q_per_step @@ -3346,7 +3332,7 @@ def backward(ctx, dout, *_args): if ctx.use_fused_attention: # Set per-step parameters for THD if ctx.qkv_format == "thd": - attn_mask_type_ = ctx.thd_attn_mask_type_per_step[i] + attn_mask_type_ = ctx.attn_mask_type cu_seqlens_q_ = thd_cu_seqlens_q_per_step[i] cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] cu_seqlens_kv_padded_ = cu_seqlens_kv_padded From 0b4874613a899e7f8809bee995d30a0f35f62780 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Apr 2026 20:57:17 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --- .../attention/dot_product_attention/context_parallel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 3af2c3454e..870e7c5e6d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2991,7 +2991,10 @@ def forward( # Per-step KV cu_seqlens (non-padded): how many actual KV tokens are # visible for each sequence. padded_chunk_sizes_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1] - thd_cu_seqlens_kv_per_step = [cu_seqlens_q_original.clone(), cu_seqlens_q_original.clone()] + thd_cu_seqlens_kv_per_step = [ + cu_seqlens_q_original.clone(), + cu_seqlens_q_original.clone(), + ] for step_idx in range(2): if causal: # Causal: visible KV covers chunks 0..chunk_id