-
Notifications
You must be signed in to change notification settings - Fork 698
[PyTorch][CP] Add THD format support for AllGather-based Context Parallelism #2829
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sudhakarsingh27
wants to merge
14
commits into
NVIDIA:main
Choose a base branch
from
sudhakarsingh27:cp_thd_swa_with_ag
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+365
−136
Open
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
db44fc2
[PyTorch][CP] Fix THD AllGather CP: offset-based approach with proper…
sudhakarsingh27 1a5ca4c
[PyTorch][CP] Enable THD+all_gather tests in test_attention_with_cp
sudhakarsingh27 b4db9eb
[PyTorch][Fused Attn] Fix max_logit masking for non-zero-starting cu_…
sudhakarsingh27 7491ab6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b957725
some cleanup of ag+thd impl and gate e e te test for flash+ag+thd
sudhakarsingh27 c89173c
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 18e41bd
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 0b48746
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 608106d
improve the logic and remvoe for loop from the code
sudhakarsingh27 4b95130
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 15af3af
Merge branch 'main' of github.com:NVIDIA/TransformerEngine into cp_th…
sudhakarsingh27 5bec5b3
Merge branch 'cp_thd_swa_with_ag' of github.com:sudhakarsingh27/Trans…
sudhakarsingh27 89b1066
AG+THD SWA: extend KV visibility for right window and rename a2a-spec…
sudhakarsingh27 55fc2cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -100,7 +100,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): | |
| pytest.skip("CP implementation with KV all-gather does not support bias yet!") | ||
| if qkv_format == "thd": | ||
| if cp_comm_type == "all_gather": | ||
| pytest.skip("CP implementation with KV all-gather does not support THD format yet!") | ||
| pytest.skip( | ||
| "FlashAttention does not support THD padding; use FusedAttention for" | ||
| " THD+all_gather CP." | ||
| ) | ||
| if cp_comm_type == "a2a+p2p": | ||
| pytest.skip( | ||
| "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" | ||
|
|
@@ -267,8 +270,6 @@ def test_cp_with_fused_attention( | |
| if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": | ||
| pytest.skip("THD format does not support post_scale_bias yet!") | ||
| if qkv_format == "thd": | ||
| if cp_comm_type == "all_gather": | ||
| pytest.skip("CP implementation with KV all-gather does not support THD format yet!") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A general comment - please run the CP file with "test_essential=False" offline because the essential tests may not cover everything. |
||
| if cp_comm_type == "a2a+p2p": | ||
| pytest.skip( | ||
| "CP implementation with QKVO A2A+P2P (Hierarchical A2A) does not support THD format" | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe swap the words a little bit so it doesn't sounds like FlashAttention doesn't support THD, but just our CP implementation with it doesn't? (Also, THD implies padding in our terminology?)