Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions gemma/tiled_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,27 @@
namespace gcpp {

// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
void TransposeStridedQueries(hwy::Span<float*> queries, int qkv_dim, \
hwy::Span<float> transposed_queries); \
void LocalAttentionForAllHeadsTokensAndBatch( \
AttentionImpl attention_impl, const size_t num_tokens, \
const size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
\
template <typename OutT> \
std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, \
std::vector<OutT*>, AlignedFloatVector> \
TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs, \
int qkv_dim, size_t group_size); \
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
#define GEMMA_DECL_TILED_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
void TiledAttention(AttentionImpl attention_impl, size_t num_tokens, \
size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
MatMulEnv& env, int flags); \
void TransposeStridedQueries(hwy::Span<float*> queries, int qkv_dim, \
hwy::Span<float> transposed_queries); \
void LocalAttentionForAllHeadsTokensAndBatch( \
AttentionImpl attention_impl, const size_t num_tokens, \
const size_t layer_idx, const LayerWeightsPtrs& layer, \
AttentionActivationsPtrs& activations, QBatch& qbatch, \
ThreadingContext& ctx); \
\
template <typename OutT> \
std::tuple<std::vector<OutT, hwy::AlignedAllocator<OutT>>, \
std::vector<OutT*>, AlignedFloatVector> \
TransposeQueriesToGroupsOfNBF16orInt16(hwy::Span<float*> queries_ptrs, \
int qkv_dim, size_t group_size); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

// Function declarations for each SIMD target. Allows direct call from the
Expand Down
Loading