feat: filter logits by loss_mask before log_probs/entropy computation #1905
Open
Taosheng-ty wants to merge 1 commit into
Conversation
For multi-turn agent rollouts where tool-result tokens dominate the
response (often >90%), computing log-probs and entropy for all positions
wastes memory and compute — those masked positions contribute zeros to
the loss anyway.
This adds a loss_masks parameter to get_log_probs_and_entropy. When
provided (and cp_size == 1), only positions where mask == 1 go through
the expensive vocab-parallel softmax. Outputs are padded back to the
original response length with zeros so all downstream code (advantages,
sum_of_sample_mean, etc.) works unchanged.
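For concreteness, a minimal single-sample sketch of the filter-and-pad pattern; `masked_log_probs` and `compute_log_probs` are hypothetical stand-ins, not this repo's functions:

```python
import torch

def masked_log_probs(logits, target_ids, loss_mask, compute_log_probs):
    """Run the expensive log-prob computation only where loss_mask == 1,
    then scatter the results into a zero-filled tensor of the original
    response length.

    compute_log_probs: stand-in for the vocab-parallel softmax path;
    maps ([n, vocab] logits, [n] targets) -> [n] log-probs.
    """
    seq_len = logits.size(0)
    idx = loss_mask.nonzero(as_tuple=True)[0]   # positions kept in the loss
    if idx.numel() == seq_len:                  # fully unmasked: skip filtering
        return compute_log_probs(logits, target_ids)
    filtered = compute_log_probs(logits[idx], target_ids[idx])
    out = torch.zeros(seq_len, dtype=filtered.dtype, device=filtered.device)
    out[idx] = filtered   # indexed assignment keeps gradients flowing
    return out
```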
Typical savings for agentic workloads:
- 97% masked tokens → ~30x reduction in softmax compute
- Prevents OOM on long multi-turn samples with large tool outputs
- Communication in vocab-parallel all-reduces drops proportionally
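For intuition, the ~30x figure is just the reciprocal of the kept fraction: with 97% of positions masked, only 3% reach the softmax, so compute drops by roughly 1 / (1 − 0.97) ≈ 33x.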
Limitations:
- Only active when cp_size == 1 (falls through to unfiltered path
for context parallelism > 1)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed from d9cde49 to 3643dea
Summary
- Filter to `loss_mask == 1` positions before the expensive vocab-parallel softmax in `get_log_probs_and_entropy`, then pad the output back to the original length with zeros
- Only active when `cp_size == 1` (falls through to the unfiltered path for context parallelism > 1)

Motivation
In agentic workloads (tool-use, multi-turn), the "response" includes large tool-result tokens that are masked out in the loss. Without this change, all those positions still go through the full vocab-parallel softmax (all-reduce MAX + SUM + SUM), wasting memory and compute only to be multiplied by zero downstream.
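For context on that cost, a hedged sketch of a Megatron-style vocab-parallel log-prob for one shard (the helper and its arguments are illustrative, not this repo's API); every position, masked or not, pays all three all-reduces unless filtered first:

```python
import torch
import torch.distributed as dist

def vocab_parallel_log_prob(logits_shard, targets, vocab_start, group):
    """Per-position log-prob when each rank holds a vocabulary shard.

    logits_shard: [seq, shard_vocab] local slice of the vocabulary.
    targets:      [seq] global token ids.
    """
    # 1) all-reduce MAX: global max for numerical stability
    logits_max = logits_shard.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
    shifted = logits_shard - logits_max.unsqueeze(-1)

    # 2) all-reduce SUM: global softmax normalizer
    sum_exp = shifted.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=group)

    # 3) all-reduce SUM: fetch the target logit from whichever shard has it
    local_ids = targets - vocab_start
    in_shard = (local_ids >= 0) & (local_ids < logits_shard.size(-1))
    target_logit = torch.where(
        in_shard,
        shifted.gather(
            -1, local_ids.clamp(0, logits_shard.size(-1) - 1).unsqueeze(-1)
        ).squeeze(-1),
        torch.zeros_like(sum_exp),
    )
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=group)
    return target_logit - sum_exp.log()
```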
Typical savings for our workload:

- 97% masked tokens → ~30x reduction in softmax compute
- Prevents OOM on long multi-turn samples with large tool outputs
- Communication in vocab-parallel all-reduces drops proportionally
Changes
- `loss.py`: Add `loss_masks` param to `get_log_probs_and_entropy`; filter before compute, pad back after
- `loss.py`: Add `loss_masks` param to `get_values` (accepts but ignores, for signature compatibility with `forward_only`)
- `loss.py`: Pass `loss_masks` in `policy_loss_function` and `sft_loss_function` callers
- `model.py`: Pass `loss_masks` in `forward_only` partial

Design
The output shape is unchanged (`[response_length]` per sample), so all downstream code (`sum_of_sample_mean`, advantages, `old_log_probs`) works without modification. Zeros at masked positions are fine because `sum_of_sample_mean` already multiplies by `loss_masks`.
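To see why the zero padding is harmless downstream, a toy version of the masked per-sample mean (the real `sum_of_sample_mean` may differ in detail):

```python
import torch

def sum_of_sample_mean(values, loss_masks):
    """Toy masked per-sample mean: masked positions are multiplied by
    zero, so whatever sits there (including our padded zeros) is ignored."""
    return sum(
        (v * m).sum() / m.sum().clamp(min=1)
        for v, m in zip(values, loss_masks)
    )

log_probs = [torch.tensor([0.0, -1.2, 0.0, -0.7])]  # zeros at masked positions
masks = [torch.tensor([0.0, 1.0, 0.0, 1.0])]
print(sum_of_sample_mean(log_probs, masks))         # tensor(-0.9500)
```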
Limitations

Only active when `cp_size == 1`. With context parallelism > 1, the response is split across ranks in a zigzag pattern, and mask alignment requires additional offset logic. This can be added in a follow-up.

Test plan
- `torch.zeros` + indexed assignment produces the correct `IndexPutBackward0` autograd node
- Output is padded back to the original response length (`[response_length]`)
- `calculate_log_probs_and_entropy` handles empty `[0, vocab]` input
- Fully unmasked sample (`mask.sum() == orig_len` → skips filtering)
- `get_values` accepts the `loss_masks` param without crash (signature compatibility)

🤖 Generated with Claude Code
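As a supplement to the test plan, a minimal self-contained version of the autograd check (behavior assumed on recent PyTorch; the `grad_fn` name can vary across versions):

```python
import torch

# Zero-init + indexed assignment should keep gradients flowing through
# the padded output, and only into the unmasked rows.
logits = torch.randn(6, 10, requires_grad=True)
mask = torch.tensor([0, 1, 1, 0, 0, 1])
idx = mask.nonzero(as_tuple=True)[0]

# Stand-in for the real per-token log-prob (no target gather here).
log_probs = torch.log_softmax(logits[idx], dim=-1).max(dim=-1).values
out = torch.zeros(6)
out[idx] = log_probs

print(type(out.grad_fn).__name__)  # typically 'IndexPutBackward0'
out.sum().backward()
assert logits.grad[idx].abs().sum() > 0          # grads reach kept rows
assert logits.grad[mask == 0].abs().sum() == 0   # masked rows get none
```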