fix: use integer division (//) in GQA reshape operations #643

Open
tomohiro86 wants to merge 1 commit into google-deepmind:main from tomohiro86:fix/gqa-use-integer-division

@tomohiro86

Summary

Fixes #641.

In GQA (Grouped Query Attention) reshape operations, int(kg / self.num_kv_heads) was using float division (/) instead of integer division (//). When kg is not evenly divisible by num_kv_heads, Python silently truncates the result via int(), producing an incorrect tensor shape with no error.

  • Replace `int(kg / self.num_kv_heads)` with `kg // self.num_kv_heads` in all 8 occurrences across 4 files
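
To make the difference between the two expressions concrete, here is a minimal sketch that evaluates them in isolation. It is not taken from the repository: `groups_via_float` and `groups_via_floor` are hypothetical names, and only `kg` and `num_kv_heads` come from the PR description.

```python
def groups_via_float(kg: int, num_kv_heads: int) -> int:
    """Old form: true division yields a float, then int() truncates it."""
    return int(kg / num_kv_heads)


def groups_via_floor(kg: int, num_kv_heads: int) -> int:
    """New form: floor division stays in exact integer arithmetic."""
    return kg // num_kv_heads


# Divisible case: both forms agree.
assert groups_via_float(8, 2) == groups_via_floor(8, 2) == 4

# Non-divisible case: both forms round down for positive inputs, so neither
# expression signals the problem by itself.
assert groups_via_float(7, 2) == groups_via_floor(7, 2) == 3

# Very large values (far beyond realistic head counts): the float path can
# lose precision, while floor division remains exact.
big = 2**53 + 1
assert groups_via_floor(big, 1) == big
assert groups_via_float(big, 1) != big
```

The integer form avoids the round trip through floating point and states the intent directly. Any error for a non-divisible configuration has to come from a later step, such as the reshape itself or an explicit check.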

Changed files

| File | Occurrences fixed |
| --- | --- |
| gemma/gm/nn/_modules.py | 2 |
| gemma/gm/nn/gemma3n/_modules.py | 2 |
| gemma/gm/nn/gemma4/_modules.py | 2 |
| gemma/research/t5gemma/modules.py | 2 |

Test plan

  • Existing unit tests pass with standard configs (where num_query_heads is a multiple of num_kv_heads)
  • Behavior is identical when divisible — // and int(/) produce the same result in that case
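  • With non-divisible configs, the expected behavior is a shape error raised by the JAX/NumPy reshape instead of silent truncation; because // also rounds down for positive integers, that error would have to come from the reshape's size check rather than from the division itself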
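
One way to exercise the non-divisible case independently of the reshape is an explicit divisibility check. The sketch below is an assumption about how such a check could look, not code from this PR: `query_groups_per_kv_head` is a hypothetical helper, and its error messages are illustrative.

```python
def query_groups_per_kv_head(kg: int, num_kv_heads: int) -> int:
    """Return query heads per KV head, rejecting non-divisible configs.

    Hypothetical helper for illustration; not part of the Gemma codebase.
    """
    if num_kv_heads <= 0:
        raise ValueError(f"num_kv_heads must be positive, got {num_kv_heads}")
    groups, remainder = divmod(kg, num_kv_heads)
    if remainder != 0:
        raise ValueError(
            f"kg={kg} is not divisible by num_kv_heads={num_kv_heads}"
        )
    return groups


# Divisible config: matches both `kg // n` and `int(kg / n)`.
assert query_groups_per_kv_head(8, 2) == 8 // 2 == int(8 / 2)

# Non-divisible config: fails loudly instead of rounding down.
try:
    query_groups_per_kv_head(7, 2)
except ValueError as err:
    print(err)
```

A check of this kind makes the failure mode explicit and gives a unit test something concrete to assert on, regardless of how the downstream reshape reports a size mismatch.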

🤖 Generated with Claude Code

…eshape

Replaces `int(kg / self.num_kv_heads)` with `kg // self.num_kv_heads` in
all GQA reshape operations across 4 modules. Using float division with
int() silently truncates when kg is not evenly divisible by num_kv_heads,
producing an incorrect tensor shape with no error. Fixes google-deepmind#641.

Development

Successfully merging this pull request may close these issues.

Bug: float division instead of integer division in GQA reshape causes silent shape truncation
