Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/pytorch/attention/test_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
pytest.skip(
"FlashAttention does not support THD padding; use FusedAttention for"
" THD+all_gather CP."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe swap the words a little bit so it doesn't sounds like FlashAttention doesn't support THD, but just our CP implementation with it doesn't? (Also, THD implies padding in our terminology?)

)
if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
Expand Down Expand Up @@ -267,8 +270,6 @@ def test_cp_with_fused_attention(
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd":
if cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A general comment - please run the CP file with "test_essential=False" offline because the essential tests may not cover everything.

if cp_comm_type == "a2a+p2p":
pytest.skip(
"CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format"
Expand Down
Loading
Loading