diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 5aaf67061b..bbfecd6969 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -100,7 +100,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip("CP implementation with KV all-gather does not support bias yet!") if qkv_format == "thd": if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") + pytest.skip( + "FlashAttention does not support THD padding; use FusedAttention for" + " THD+all_gather CP." + ) if cp_comm_type == "a2a+p2p": pytest.skip( "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" @@ -267,8 +270,6 @@ def test_cp_with_fused_attention( if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd": - if cp_comm_type == "all_gather": - pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if cp_comm_type == "a2a+p2p": pytest.skip( "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 64cccaac6e..870e7c5e6d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -2805,9 +2805,11 @@ def forward( k, v, cu_seqlens_q, + cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_padded, + cu_seqlens_kv_padded, dropout_p, softmax_scale, qkv_format, @@ -2834,7 +2836,11 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type - assert not padding, 
f"{attn_mask_type} mask type is not supported!" + if qkv_format == "thd": + # THD always uses padding mask types; per-step masks set internally + assert padding, f"THD format requires padding mask type, got {attn_mask_type}!" + else: + assert not padding, f"{attn_mask_type} mask type is not supported!" if use_fused_attention and causal and "bottom_right" not in attn_mask_type: attn_mask_type = attn_mask_type + "_bottom_right" assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" @@ -2874,41 +2880,61 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - seq_dim = qkv_format.index("s") - assert ( - q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 - ), "Sequence length per GPU needs to be divisible by 2!" + if qkv_format == "thd": + # Save original global cu_seqlens before division + cu_seqlens_q_original = cu_seqlens_q.clone() + cu_seqlens_kv_original = cu_seqlens_kv.clone() + else: + seq_dim = qkv_format.index("s") + assert ( + q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 + ), "Sequence length per GPU needs to be divisible by 2!" 
+ # Divide by 2*cp_size to get per-chunk values max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) - if use_fused_attention or qkv_format == "thd": + if use_fused_attention and qkv_format != "thd": cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - if cu_seqlens_q_padded is not None and qkv_format == "thd": + if qkv_format == "thd": cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + cu_seqlens_kv_padded = cu_seqlens_kv_padded // (2 * cp_size) else: cu_seqlens_q_padded = None - # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] - q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] - k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] + if qkv_format != "thd": + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] + q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) + # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] + k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] - # [s, b, h, d] -> [cp, s, b, h, d] + # AllGather K/V across CP ranks k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] - k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) - v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) - k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) - v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - k_ag = k_ag.view(-1, *k.shape[1:]) - v_ag = v_ag.view(-1, *v.shape[1:]) + if qkv_format == "thd": + # [cp*t, h, d] -> reorder to contiguous per-sequence order -> [t_full, h, d] + # Use padded cu_seqlens since reorder computes slice boundaries via integer + # division by 2*cp_size, which 
requires divisible values. + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + k_ag, cu_seqlens_kv_padded * 2 * cp_size, chunk_ids_for_kv_ag, cp_size + ) + v_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + v_ag, cu_seqlens_kv_padded * 2 * cp_size, chunk_ids_for_kv_ag, cp_size + ) + else: + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) + cp_stream.wait_stream(torch.cuda.current_stream()) # create two streams to resolve wave quantization issue of Flash Attn in each step @@ -2921,39 +2947,119 @@ def forward( out_per_step = [None, None] softmax_lse_per_step = [None, None] rng_states = [None, None] - out = torch.empty_like(q) + out = torch.zeros_like(q) if qkv_format == "thd" else torch.empty_like(q) max_logit_per_step = [None, None] max_logit = None + # Pre-compute THD-specific per-step cu_seqlens + if qkv_format == "thd": + # Rank-level padded offsets (2 chunks per sequence on this rank) + cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 + + # Per-step Q cu_seqlens (non-padded): different per step since different + # chunks may have different valid token counts for non-divisible seqlens. 
+ thd_cu_seqlens_q_per_step = [ + get_cu_seqlens_on_cp_rank( + cu_seqlens_q_original, + cu_seqlens_q_padded_rank, + cp_size, + rank, + True, + False, + ), + get_cu_seqlens_on_cp_rank( + cu_seqlens_q_original, + cu_seqlens_q_padded_rank, + cp_size, + rank, + False, + True, + ), + ] + + # Per-step Q cu_seqlens_padded: offset-based approach — pass full Q tensor + # and vary cu_seqlens_q_padded to point kernel at the correct chunk. + # cuDNN uses back-padding (valid tokens at beginning of padded allocation). + padded_chunk_sizes_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1] + actual_seqlens_kv = cu_seqlens_kv_original[1:] - cu_seqlens_kv_original[:-1] + # Step 0: kernel reads from start of each seq's 2-chunk allocation (first chunk) + # Step 1: kernel reads from midpoint of each seq's allocation (second chunk) + thd_cu_seqlens_q_padded_per_step = [cu_seqlens_q_padded_rank, None] + thd_cu_seqlens_q_padded_per_step[1] = cu_seqlens_q_padded_rank.clone() + thd_cu_seqlens_q_padded_per_step[1][:-1] += padded_chunk_sizes_q + + # Per-step KV cu_seqlens (non-padded): how many actual KV tokens are + # visible for each sequence. 
+ padded_chunk_sizes_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1] + thd_cu_seqlens_kv_per_step = [ + cu_seqlens_kv_original.clone(), # non-causal default: all KV tokens visible + cu_seqlens_kv_original.clone(), + ] + for step_idx in range(2): + if causal: + # Causal: visible KV covers chunks 0..chunk_id + chunk_id = local_seq_chunk_ids[step_idx] + visible_padded = padded_chunk_sizes_kv * (chunk_id + 1) + visible_actual = torch.minimum(actual_seqlens_kv, visible_padded) + cs = torch.zeros_like(cu_seqlens_kv_original) + cs[1:] = visible_actual.cumsum(0) + thd_cu_seqlens_kv_per_step[step_idx] = cs + for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() - kv_seq_range_per_step[i], window_size_per_step[i] = ( - get_kv_seq_info_after_all_gather( - local_seq_chunk_ids[i], - cp_size, - max_seqlen_q, - max_seqlen_kv, - window_size, - causal, + if qkv_format in ["bshd", "sbhd"]: + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_ = q.select(seq_dim, i).contiguous() + kv_seq_range_per_step[i], window_size_per_step[i] = ( + get_kv_seq_info_after_all_gather( + local_seq_chunk_ids[i], + cp_size, + max_seqlen_q, + max_seqlen_kv, + window_size, + causal, + ) ) - ) - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i][0], - kv_seq_range_per_step[i][1], - ) - max_seqlen_kv_ = seq_end_idx - seq_start_idx - if use_fused_attention or qkv_format == "thd": - cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( - k.shape[1], max_seqlen_kv_, k.device + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], ) - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + max_seqlen_kv_ = 
seq_end_idx - seq_start_idx + if use_fused_attention: + cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] + # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + elif qkv_format == "thd": + # THD: pass full Q with per-step cu_seqlens_q_padded to select chunk + q_ = q + k_ = k_ag + v_ = v_ag + chunk_id = local_seq_chunk_ids[i] + if causal: + max_seqlen_kv_ = max_seqlen_kv * (chunk_id + 1) + else: + max_seqlen_kv_ = max_seqlen_kv * (2 * cp_size) + cu_seqlens_kv_per_step[i] = thd_cu_seqlens_kv_per_step[i] + # Window size + if window_size is None: + window_size_per_step[i] = (-1, 0) if causal else (-1, -1) + else: + window_size_per_step[i] = window_size if use_fused_attention: + # Set per-step parameters for THD vs bshd/sbhd + if qkv_format == "thd": + cu_seqlens_q_ = thd_cu_seqlens_q_per_step[i] + cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + else: + cu_seqlens_q_ = cu_seqlens_q + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] ( out_per_step[i], [softmax_lse_per_step[i], rng_states[i]], @@ -2962,7 +3068,7 @@ def forward( is_training, max_seqlen_q, max_seqlen_kv_, - cu_seqlens_q, + cu_seqlens_q_, cu_seqlens_kv_per_step[i], q_, k_, @@ -2975,8 +3081,8 @@ def forward( attn_mask_type=attn_mask_type, attn_bias_type=attn_bias_type, attn_bias=attn_bias, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, window_size=window_size_per_step[i], return_max_logit=return_max_logit, cuda_graph=is_graph_capturing(), @@ -2988,7 +3094,11 @@ def forward( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, + cu_seqlens_q=( + 
thd_cu_seqlens_q_per_step[i] + if qkv_format == "thd" + else cu_seqlens_q + ), cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv_, @@ -3025,6 +3135,18 @@ def forward( out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) + elif qkv_format == "thd": + # Copy valid token ranges from this step's output. + # Each step writes at different positions (no overlap, no correction needed). + step_padded = thd_cu_seqlens_q_padded_per_step[i - 1] + step_valid = thd_cu_seqlens_q_per_step[i - 1] + batch_size = step_valid.shape[0] - 1 + for b in range(batch_size): + s = step_padded[b].item() + sz = (step_valid[b + 1] - step_valid[b]).item() + if sz > 0: + out[s : s + sz].copy_(out_per_step[i - 1][s : s + sz]) + if return_max_logit: max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) @@ -3034,7 +3156,9 @@ def forward( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - if use_fused_attention: + if qkv_format == "thd": + pass # out is already [t_rank, h, d], no reshape needed + elif use_fused_attention: if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) elif qkv_format == "sbhd": @@ -3068,6 +3192,11 @@ def forward( ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 + if qkv_format == "thd": + ctx.max_seqlen_kv = max_seqlen_kv + ctx.cu_seqlens_kv_padded = cu_seqlens_kv_padded + ctx.thd_cu_seqlens_q_per_step = thd_cu_seqlens_q_per_step + ctx.thd_cu_seqlens_q_padded_per_step = thd_cu_seqlens_q_padded_per_step nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if return_max_logit: return out, max_logit @@ -3089,11 +3218,12 @@ def backward(ctx, dout, *_args): kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step - seq_dim = ctx.qkv_format.index("s") + if ctx.qkv_format != "thd": + seq_dim = 
ctx.qkv_format.index("s") qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format dout = dout.view(q.shape) - dq = torch.empty_like(q) + dq = torch.zeros_like(q) if ctx.qkv_format == "thd" else torch.empty_like(q) dk = torch.zeros((k.shape[0] * cp_size, *k.shape[1:]), dtype=k.dtype, device=k.device) dv = torch.zeros_like(dk) dq_per_step = [None, None] @@ -3105,19 +3235,36 @@ def backward(ctx, dout, *_args): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, h, d] -> [cp, s, b, h, d] + # AllGather K/V across CP ranks k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] - k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) - v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) - k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) - v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - k_ag = k_ag.view(-1, *k.shape[1:]) - v_ag = v_ag.view(-1, *v.shape[1:]) + if ctx.qkv_format == "thd": + cu_seqlens_kv_padded = ctx.cu_seqlens_kv_padded + thd_cu_seqlens_q_per_step = ctx.thd_cu_seqlens_q_per_step + cu_seqlens_q_padded_rank = cu_seqlens_q_padded * 2 + + # [cp*t, h, d] -> reorder to contiguous per-sequence order + # ctx stores per-chunk padded cu_seqlens; rescale to global values as forward does + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + k_ag, cu_seqlens_kv_padded * 2 * cp_size, chunk_ids_for_kv_ag, cp_size + ) + v_ag = reorder_seq_chunks_after_a2a_before_attn_thd( + v_ag, cu_seqlens_kv_padded * 2 * cp_size, chunk_ids_for_kv_ag, cp_size + ) + + thd_cu_seqlens_q_padded_per_step = ctx.thd_cu_seqlens_q_padded_per_step + else: + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] + k_ag = k_ag.view(2 * 
cp_size, k.shape[0] // 2, *k.shape[1:]) + v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) + k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) + v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + k_ag = k_ag.view(-1, *k.shape[1:]) + v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) local_seq_chunk_ids = [rank, 2 * cp_size - rank - 1] @@ -3156,25 +3303,50 @@ def backward(ctx, dout, *_args): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] - # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] - q_ = q.select(seq_dim, i).contiguous() - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i][0], - kv_seq_range_per_step[i][1], - ) - max_seqlen_kv = seq_end_idx - seq_start_idx - k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] - k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] - out_ = out_per_step[i] - dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) + if ctx.qkv_format == "thd": + # THD: pass full Q/dout with per-step cu_seqlens_q_padded + q_ = q + k_ = k_ag + v_ = v_ag + chunk_id = local_seq_chunk_ids[i] + causal = "causal" in ctx.attn_mask_type + if causal: + max_seqlen_kv = ctx.max_seqlen_kv * (chunk_id + 1) + else: + max_seqlen_kv = ctx.max_seqlen_kv * (2 * cp_size) + out_ = out_per_step[i] + dout_ = dout + else: + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_ = q.select(seq_dim, i).contiguous() + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i][0], + kv_seq_range_per_step[i][1], + ) + max_seqlen_kv = seq_end_idx - seq_start_idx + k_, v_ = [x[seq_start_idx:seq_end_idx] for x in 
[k_ag, v_ag]] + # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] + k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] + out_ = out_per_step[i] + dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: + # Set per-step parameters for THD + if ctx.qkv_format == "thd": + attn_mask_type_ = ctx.attn_mask_type + cu_seqlens_q_ = thd_cu_seqlens_q_per_step[i] + cu_seqlens_q_padded_ = thd_cu_seqlens_q_padded_per_step[i] + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + else: + cu_seqlens_q_ = cu_seqlens_q + attn_mask_type_ = ctx.attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_per_step[i] aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, + cu_seqlens_q_, cu_seqlens_kv_per_step[i], q_, k_, @@ -3185,12 +3357,12 @@ def backward(ctx, dout, *_args): TE_DType[dout.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, attn_scale=ctx.softmax_scale, dropout=ctx.dropout_p, qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, + attn_mask_type=attn_mask_type_, attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, @@ -3204,7 +3376,11 @@ def backward(ctx, dout, *_args): False, ctx.use_flash_attn_3, ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q, + cu_seqlens_q=( + thd_cu_seqlens_q_per_step[i] + if ctx.qkv_format == "thd" + else cu_seqlens_q + ), cu_seqlens_kv=cu_seqlens_kv_per_step[i], max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=max_seqlen_kv, @@ -3236,44 +3412,72 @@ def backward(ctx, dout, *_args): if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): - if ctx.qkv_format == "bshd": - 
dq[:, i - 1].copy_(dq_per_step[i - 1]) - elif ctx.qkv_format == "sbhd": - dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] - dk_per_step[i - 1], dv_per_step[i - 1] = [ - x.movedim(seq_dim, 0).contiguous() - for x in [dk_per_step[i - 1], dv_per_step[i - 1]] - ] - # wait until dkv update of last step is done - if i > 1: - flash_attn_streams[i - 1].wait_event(dkv_update_done) - seq_start_idx, seq_end_idx = ( - kv_seq_range_per_step[i - 1][0], - kv_seq_range_per_step[i - 1][1], - ) - dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) - dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) - if i < len(local_seq_chunk_ids): - flash_attn_streams[i - 1].record_event(dkv_update_done) + if ctx.qkv_format == "thd": + # dQ: copy valid token ranges from this step's dQ + step_padded = thd_cu_seqlens_q_padded_per_step[i - 1] + step_valid = thd_cu_seqlens_q_per_step[i - 1] + batch_size = step_valid.shape[0] - 1 + for b in range(batch_size): + s = step_padded[b].item() + sz = (step_valid[b + 1] - step_valid[b]).item() + if sz > 0: + dq[s : s + sz].copy_(dq_per_step[i - 1][s : s + sz]) + # dK/dV: add full tensor (kernel zeros non-valid positions) + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + dk.add_(dk_per_step[i - 1]) + dv.add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) + else: + if ctx.qkv_format == "bshd": + dq[:, i - 1].copy_(dq_per_step[i - 1]) + elif ctx.qkv_format == "sbhd": + dq[i - 1].copy_(dq_per_step[i - 1]) + # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] + dk_per_step[i - 1], dv_per_step[i - 1] = [ + x.movedim(seq_dim, 0).contiguous() + for x in [dk_per_step[i - 1], dv_per_step[i - 1]] + ] + # wait until dkv update of last step is done + if i > 1: + flash_attn_streams[i - 1].wait_event(dkv_update_done) + seq_start_idx, seq_end_idx = ( + kv_seq_range_per_step[i - 1][0], + kv_seq_range_per_step[i - 
1][1], + ) + dk[seq_start_idx:seq_end_idx].add_(dk_per_step[i - 1]) + dv[seq_start_idx:seq_end_idx].add_(dv_per_step[i - 1]) + if i < len(local_seq_chunk_ids): + flash_attn_streams[i - 1].record_event(dkv_update_done) torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] - dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) - dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) - dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) - dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] - dk = dk.view(-1, *dk.shape[-3:]) - dv = dv.view(-1, *dv.shape[-3:]) - dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) - dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) - - dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) - dk = dk.movedim(0, seq_dim).contiguous() - dv = dv.movedim(0, seq_dim).contiguous() + if ctx.qkv_format == "thd": + # Reverse-reorder dK/dV from contiguous order back to dual-chunk order, + # then reduce-scatter across CP ranks. + # Use padded cu_seqlens for correct slice boundaries. 
+ kv_padded_global = cu_seqlens_kv_padded * 2 * cp_size # ctx holds per-chunk values + dk = reorder_seq_chunks_before_a2a_after_attn_thd(dk, kv_padded_global, cp_size) + dv = reorder_seq_chunks_before_a2a_after_attn_thd(dv, kv_padded_global, cp_size) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) # dQ needs no reshape + else: + # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] + dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) + dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) + dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) + dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] + dk = dk.view(-1, *dk.shape[-3:]) + dv = dv.view(-1, *dv.shape[-3:]) + dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) + dv, _ = reduce_scatter_along_first_dim(dv, ctx.cp_group) + + dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) + dk = dk.movedim(0, seq_dim).contiguous() + dv = dv.movedim(0, seq_dim).contiguous() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( @@ -3298,6 +3502,8 @@ def backward(ctx, dout, *_args): None, None, None, + None, + None, ) @@ -4129,8 +4335,6 @@ def attn_forward_func_with_cp( ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": - args.pop(5) - args.pop(8) args += [window_size, cp_group, cp_stream, use_flash_attn_3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 06bfb6ef3c..52e42bdc0a 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -317,7 +317,6 @@ def fused_attn_fwd( raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel - output_tensors = 
tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, @@ -377,23 +376,18 @@ def fused_attn_fwd( if cu_seqlens_q_padded is not None: # For THD + pad_between_seqs=True + non-sm120 + cuDNN>9.6, Max tensor is [tq, h, 1] # and padding positions could be uninitialized. Exclude those padded positions when - # computing max_logit. + # computing max_logit. Use absolute positions from cu_seqlens_q_padded to handle + # cases where cu_seqlens_q_padded may not start at 0 (e.g. CP offset-based approach). actual_seqlens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to( device=max_tensor.device ) - padded_seqlens = (cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]).to( - device=max_tensor.device - ) - pad_lens = (padded_seqlens - actual_seqlens).to(device=max_tensor.device) - b = pad_lens.shape[0] - - # Stack [actual, pad] per batch into counts: e.g. [3,1, 3,1, 2,2, 7,1] - counts = torch.stack([actual_seqlens, pad_lens], dim=1).flatten() - # Tile [T, F] per sequence: [T,F, T,F, T,F, T,F] - values = torch.tensor([True, False], device=max_tensor.device).repeat(b) - # Expand: T×3, F×1, T×3, F×1, T×2, F×2, T×7, F×1 → TTTF|TTTF|TTFF|TTTTTTTF - valid = torch.repeat_interleave(values, counts) - # Finally, replace invalid (F) positions with -inf + tq = max_tensor.shape[0] + valid = torch.zeros(tq, dtype=torch.bool, device=max_tensor.device) + b = actual_seqlens.shape[0] + for b_idx in range(b): + start = cu_seqlens_q_padded[b_idx].item() + n_valid = actual_seqlens[b_idx].item() + valid[start : start + n_valid] = True max_tensor = max_tensor.masked_fill(~valid.view(-1, 1, 1), float("-inf")) # Max -> max_logit [h]