7 changes: 4 additions & 3 deletions tests/pytorch/attention/test_attention_with_cp.py
@@ -100,7 +100,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
pytest.skip(
"FlashAttention does not support THD padding; use FusedAttention for"
" THD+all_gather CP."
)
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
@@ -267,8 +270,6 @@ def test_cp_with_fused_attention(
     if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
         pytest.skip("THD format does not support post_scale_bias yet!")
     if qkv_format == "thd":
-        if cp_comm_type == "all_gather":
-            pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
         if cp_comm_type == "a2a+p2p":
             pytest.skip(
                 "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"