|
15 | 15 | namespace gcpp { |
16 | 16 |
|
17 | 17 | // Passed to HWY_VISIT_TARGETS; declares for one target. |
18 | | -#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \ |
19 | | - namespace NAMESPACE { \ |
20 | | - void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \ |
21 | | - size_t layer_idx, const LayerWeightsPtrs& layer, \ |
22 | | - AttentionActivationsPtrs& activations, QBatch& qbatch, \ |
23 | | - MatMulEnv& env, int flags); \ |
24 | | - void TransposeStridedQueries(hwy::Span<float*> queries, int qkv_dim, \ |
25 | | - hwy::Span<float> transposed_queries); \ |
26 | | - void LocalAttentionForAllHeadsTokensAndBatch( \ |
27 | | - AttentionImpl attention_impl, const size_t num_tokens, \ |
28 | | - const size_t layer_idx, const LayerWeightsPtrs& layer, \ |
29 | | - AttentionActivationsPtrs& activations, QBatch& qbatch, \ |
30 | | - ThreadingContext& ctx); \ |
31 | | - \ |
32 | | - template <typename OutT> \ |
33 | | - std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, \ |
34 | | - std::vector<OutT*>, AlignedFloatVector> \ |
35 | | - TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs, \ |
36 | | - int qkv_dim, size_t group_size); \ |
37 | | - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ |
| 18 | +#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \ |
| 19 | + namespace NAMESPACE { \ |
| 20 | + void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \ |
| 21 | + size_t layer_idx, const LayerWeightsPtrs& layer, \ |
| 22 | + AttentionActivationsPtrs& activations, QBatch& qbatch, \ |
| 23 | + MatMulEnv& env, int flags); \ |
| 24 | + void TransposeStridedQueries(hwy::Span<float*> queries, int qkv_dim, \ |
| 25 | + hwy::Span<float> transposed_queries); \ |
| 26 | + void LocalAttentionForAllHeadsTokensAndBatch( \ |
| 27 | + AttentionImpl attention_impl, const size_t num_tokens, \ |
| 28 | + const size_t layer_idx, const LayerWeightsPtrs& layer, \ |
| 29 | + AttentionActivationsPtrs& activations, QBatch& qbatch, \ |
| 30 | + ThreadingContext& ctx); \ |
| 31 | + \ |
| 32 | + template <typename OutT> \ |
| 33 | + std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, \ |
| 34 | + std::vector<OutT*>, AlignedFloatVector> \ |
| 35 | + TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs, \ |
| 36 | + int qkv_dim, size_t group_size); \ |
| 37 | + \ |
| 38 | + std::tuple<std::vector<int16_t, hwy::AlignedAllocator<int16_t>>, \ |
| 39 | + std::vector<int16_t*>, AlignedFloatVector> \ |
| 40 | + TransposeQueriesToGroupsOfNInt16(hwy::Span<float*> queries_ptrs, \ |
| 41 | + int qkv_dim, size_t group_size); \ |
| 42 | + \ |
| 43 | + std::tuple<std::vector<BF16, hwy::AlignedAllocator<BF16>>, \ |
| 44 | + std::vector<BF16*>, AlignedFloatVector> \ |
| 45 | + TransposeQueriesToGroupsOfNBF16(hwy::Span<float*> queries_ptrs, int qkv_dim, \ |
| 46 | + size_t group_size); \ |
| 47 | + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ |
38 | 48 | } // namespace NAMESPACE |
39 | 49 |
|
40 | 50 | // Function declarations for each SIMD target. Allows direct call from the |
|
0 commit comments