Skip to content

Commit 3eb6060

Browse files
Krzysztof Rymski, copybara-github
authored and committed
Expose transposition functions in the tiled attention
PiperOrigin-RevId: 888724073
1 parent f56d18d commit 3eb6060

2 files changed

Lines changed: 46 additions & 20 deletions

File tree

gemma/tiled_attention.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,22 @@ TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs,
502502
std::move(q_scales));
503503
}
504504

505+
// Non-template entry point for the int16_t variant of the query transposition.
// Forwards queries_ptrs, qkv_dim and group_size unchanged to
// TransposeQueriesToGroupsOfNBF16orInt16<int16_t> and returns its result as-is:
// an aligned vector of int16_t, a vector of int16_t pointers, and an
// AlignedFloatVector.
// NOTE(review): the AlignedFloatVector appears to be the q_scales produced by
// the template; confirm against the template's definition.
std::tuple<std::vector<int16_t, hwy::AlignedAllocator<int16_t>>,
506+
std::vector<int16_t*>, AlignedFloatVector>
507+
TransposeQueriesToGroupsOfNInt16(hwy::Span<float*> queries_ptrs, int qkv_dim,
508+
size_t group_size) {
509+
return TransposeQueriesToGroupsOfNBF16orInt16<int16_t>(queries_ptrs, qkv_dim,
510+
group_size);
511+
}
512+
513+
// Non-template entry point for the BF16 variant of the query transposition.
// Forwards queries_ptrs, qkv_dim and group_size unchanged to
// TransposeQueriesToGroupsOfNBF16orInt16<BF16> and returns its result as-is:
// an aligned vector of BF16, a vector of BF16 pointers, and an
// AlignedFloatVector.
// NOTE(review): the AlignedFloatVector appears to be the q_scales produced by
// the template; confirm against the template's definition.
std::tuple<std::vector<BF16, hwy::AlignedAllocator<BF16>>, std::vector<BF16*>,
514+
AlignedFloatVector>
515+
TransposeQueriesToGroupsOfNBF16(hwy::Span<float*> queries_ptrs, int qkv_dim,
516+
size_t group_size) {
517+
return TransposeQueriesToGroupsOfNBF16orInt16<BF16>(queries_ptrs, qkv_dim,
518+
group_size);
519+
}
520+
505521
std::pair<AlignedBF16Vector, std::vector<BF16*>>
506522
TransposeTransposedQueriesAndPackIntoBF16(hwy::Span<float*> queries_ptrs,
507523
int qkv_dim, int num_queries) {

gemma/tiled_attention.h

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,36 @@
1515
namespace gcpp {
1616

1717
// Passed to HWY_VISIT_TARGETS; declares for one target.
18-
#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \
19-
namespace NAMESPACE { \
20-
void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \
21-
size_t layer_idx, const LayerWeightsPtrs& layer, \
22-
AttentionActivationsPtrs& activations, QBatch& qbatch, \
23-
MatMulEnv& env, int flags); \
24-
void TransposeStridedQueries(hwy::Span<float*> queries, int qkv_dim, \
25-
hwy::Span<float> transposed_queries); \
26-
void LocalAttentionForAllHeadsTokensAndBatch( \
27-
AttentionImpl attention_impl, const size_t num_tokens, \
28-
const size_t layer_idx, const LayerWeightsPtrs& layer, \
29-
AttentionActivationsPtrs& activations, QBatch& qbatch, \
30-
ThreadingContext& ctx); \
31-
\
32-
template <typename OutT> \
33-
std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, \
34-
std::vector<OutT*>, AlignedFloatVector> \
35-
TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs, \
36-
int qkv_dim, size_t group_size); \
37-
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
18+
// Expands to `namespace NAMESPACE { ... }` containing declarations only:
//   - TiledAttention
//   - TransposeStridedQueries
//   - LocalAttentionForAllHeadsTokensAndBatch
//   - the TransposeQueriesToGroupsOfNBF16orInt16<OutT> template
//   - its non-template counterparts TransposeQueriesToGroupsOfNInt16 and
//     TransposeQueriesToGroupsOfNBF16, which share the template's parameter
//     list and return the same tuple shape with OutT fixed to int16_t / BF16.
// TARGET is accepted but not referenced anywhere in the expansion.
// Every line of the macro body must end in a backslash; the closing
// `} // namespace NAMESPACE` line is the last line of the macro.
#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \
19+
namespace NAMESPACE { \
20+
void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \
21+
size_t layer_idx, const LayerWeightsPtrs& layer, \
22+
AttentionActivationsPtrs& activations, QBatch& qbatch, \
23+
MatMulEnv& env, int flags); \
24+
void TransposeStridedQueries(hwy::Span<float*> queries, int qkv_dim, \
25+
hwy::Span<float> transposed_queries); \
26+
void LocalAttentionForAllHeadsTokensAndBatch( \
27+
AttentionImpl attention_impl, const size_t num_tokens, \
28+
const size_t layer_idx, const LayerWeightsPtrs& layer, \
29+
AttentionActivationsPtrs& activations, QBatch& qbatch, \
30+
ThreadingContext& ctx); \
31+
\
32+
template <typename OutT> \
33+
std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, \
34+
std::vector<OutT*>, AlignedFloatVector> \
35+
TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs, \
36+
int qkv_dim, size_t group_size); \
37+
\
38+
std::tuple<std::vector<int16_t, hwy::AlignedAllocator<int16_t>>, \
39+
std::vector<int16_t*>, AlignedFloatVector> \
40+
TransposeQueriesToGroupsOfNInt16(hwy::Span<float*> queries_ptrs, \
41+
int qkv_dim, size_t group_size); \
42+
\
43+
std::tuple<std::vector<BF16, hwy::AlignedAllocator<BF16>>, \
44+
std::vector<BF16*>, AlignedFloatVector> \
45+
TransposeQueriesToGroupsOfNBF16(hwy::Span<float*> queries_ptrs, int qkv_dim, \
46+
size_t group_size); \
47+
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
3848
} // namespace NAMESPACE
3949

4050
// Function declarations for each SIMD target. Allows direct call from the

0 commit comments

Comments
 (0)