
feat: filter logits by loss_mask before log_probs/entropy computation #1905

Open · Taosheng-ty wants to merge 1 commit into THUDM:main from Taosheng-ty:feat/filter-logits-by-loss-mask

Conversation

@Taosheng-ty

Summary

  • Filter logits down to the loss_mask == 1 positions before the expensive vocab-parallel softmax in get_log_probs_and_entropy, then pad the output back to the original length with zeros (see the sketch after this list)
  • For multi-turn agent rollouts where tool-result tokens make up ~97% of the response, this cuts softmax compute by ~30x and prevents OOM on long samples
  • Only active when cp_size == 1 (falls back to the unfiltered path for context parallelism > 1)
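A minimal single-GPU sketch of the filter → compute → pad pattern, assuming plain PyTorch. The function name `log_probs_with_mask_filtering` and the shapes are illustrative; the real path runs the vocab-parallel softmax across tensor-parallel ranks rather than a local `log_softmax`.

```python
import torch

def log_probs_with_mask_filtering(logits, target_ids, loss_mask):
    """Compute per-token log-probs only at loss_mask == 1 positions.

    logits:     [response_len, vocab]
    target_ids: [response_len]
    loss_mask:  [response_len], 1 where the token contributes to the loss
    """
    mask = loss_mask.bool()
    if mask.sum() == mask.numel():  # nothing masked: skip filtering entirely
        lp = torch.log_softmax(logits, dim=-1)
        return lp.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)

    # Only the active rows pay for the softmax.
    active_lp = torch.log_softmax(logits[mask], dim=-1)  # [num_active, vocab]
    active_lp = active_lp.gather(-1, target_ids[mask].unsqueeze(-1)).squeeze(-1)

    # Pad back to the original length with zeros at masked positions.
    out = torch.zeros(mask.numel(), dtype=logits.dtype, device=logits.device)
    out[mask] = active_lp
    return out
```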

Motivation

In agentic workloads (tool-use, multi-turn), the "response" contains large spans of tool-result tokens that are masked out of the loss. Without this change, all of those positions still go through the full vocab-parallel softmax (all-reduce MAX + SUM + SUM), wasting memory and compute only for the result to be multiplied by zero downstream.

Typical savings for our workload:

  • 97% masked tokens → only ~3% of positions reach the softmax, a ~30x reduction in softmax compute
  • Prevents OOM on long multi-turn samples with large tool outputs
  • Communication in the vocab-parallel all-reduces drops proportionally (e.g., 4096 → ~120 positions); see the sketch below
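For reference, this is the communication the unfiltered path pays at every response position: a minimal sketch of a Megatron-style vocab-parallel log-softmax denominator, assuming an initialized tensor-parallel process group. The names `vocab_parallel_log_z` and `tp_group` are illustrative, not the project's actual API.

```python
import torch
import torch.distributed as dist

def vocab_parallel_log_z(logits_shard: torch.Tensor, tp_group) -> torch.Tensor:
    """Return log Z (the log-softmax denominator) per position.

    logits_shard: [num_positions, vocab_size // tp_size], this rank's vocab slice.
    Every all-reduce below carries one value per position, so shrinking
    num_positions from 4096 to ~120 shrinks communication proportionally.
    """
    # All-reduce #1 (MAX): global max over the full vocabulary, for stability.
    global_max = logits_shard.max(dim=-1, keepdim=True).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)

    # All-reduce #2 (SUM): sum of exp over the full vocabulary.
    sum_exp = (logits_shard - global_max).exp().sum(dim=-1, keepdim=True)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)

    # (The real kernel issues a third all-reduce (SUM) to collect each target
    # token's logit from whichever rank's vocab shard owns it.)
    return (global_max + sum_exp.log()).squeeze(-1)
```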

Changes

  • loss.py: Add loss_masks param to get_log_probs_and_entropy; filter before compute, pad back after
  • loss.py: Add loss_masks param to get_values (accepts but ignores, for signature compatibility with forward_only)
  • loss.py: Pass loss_masks in policy_loss_function and sft_loss_function callers
  • model.py: Pass loss_masks in forward_only partial

Design

```
Before: logits [response_len, vocab/tp] → softmax over ALL positions → mask afterwards
After:  filter by loss_mask → logits [num_active, vocab/tp] → softmax → pad back with zeros
```

The output shape is unchanged ([response_length] per sample), so all downstream code (sum_of_sample_mean, advantages, old_log_probs) works without modification. Zeros at masked positions are fine because sum_of_sample_mean already multiplies by loss_masks.
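A toy check of that invariant, assuming a sum_of_sample_mean-style masked reduction (the values here are made up):

```python
import torch

# Toy values: zeros sit at the masked positions.
log_probs = torch.tensor([0.0, -1.2, 0.0, -0.7])
loss_mask = torch.tensor([0.0, 1.0, 0.0, 1.0])

# sum_of_sample_mean-style reduction: masked slots contribute nothing,
# so the padded zeros never reach the loss.
loss = (log_probs * loss_mask).sum() / loss_mask.sum()
assert torch.isclose(loss, torch.tensor(-0.95))
```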

Limitations

  • Only active when cp_size == 1. With context parallelism > 1, the response is split across ranks in a zigzag pattern, and mask alignment requires additional offset logic. This can be added in a follow-up.

Test plan

  • Gradient correctness verified: torch.zeros + indexed assignment produces an IndexPutBackward0 autograd node, so gradients flow only through the active positions (see the sketch after this list)
  • Output shape matches original (padded back to [response_length])
  • All-positions-masked edge case: calculate_log_probs_and_entropy handles empty [0, vocab] input
  • No-op when all positions are unmasked (_mask.sum() == orig_len → skips filtering)
  • get_values accepts loss_masks param without crash (signature compatibility)
  • End-to-end training run with multi-turn agent data
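A minimal reproduction of the gradient check, in plain PyTorch with illustrative shapes:

```python
import torch

response_len, vocab = 6, 10
logits = torch.randn(response_len, vocab, requires_grad=True)
mask = torch.tensor([1, 0, 0, 1, 0, 1], dtype=torch.bool)

# Filter, compute a toy per-position log-prob, pad back with zeros.
active_lp = torch.log_softmax(logits[mask], dim=-1).mean(dim=-1)
out = torch.zeros(response_len, dtype=logits.dtype)
out[mask] = active_lp                        # in-place indexed assignment

print(out.grad_fn)                           # <IndexPutBackward0 ...>
out.sum().backward()
assert logits.grad[~mask].abs().sum() == 0   # masked rows receive zero gradient
```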

🤖 Generated with Claude Code

Commit message:

For multi-turn agent rollouts where tool-result tokens dominate the
response (often >90%), computing log-probs and entropy for all positions
wastes memory and compute; those masked positions contribute zeros to
the loss anyway.

This adds a loss_masks parameter to get_log_probs_and_entropy. When
provided (and cp_size == 1), only positions where mask == 1 go through
the expensive vocab-parallel softmax. Outputs are padded back to the
original response length with zeros so all downstream code (advantages,
sum_of_sample_mean, etc.) works unchanged.

Typical savings for agentic workloads:
  - 97% masked tokens → ~30x reduction in softmax compute
  - Prevents OOM on long multi-turn samples with large tool outputs
  - Communication in vocab-parallel all-reduces drops proportionally

Limitations:
  - Only active when cp_size == 1 (falls through to unfiltered path
    for context parallelism > 1)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Taosheng-ty force-pushed the feat/filter-logits-by-loss-mask branch from d9cde49 to 3643dea on May 13, 2026 at 01:46.
