@@ -341,46 +341,15 @@ void extTokenSoftmaxReduceVInference(const at::Tensor& logics,
341341 b_start_loc, b_seq_len, max_input_len, other_kv_index);
342342}
343343
344- // void extTokenDecodeAttentionInference(const at::Tensor& q, const at::Tensor& k,
345- // const at::Tensor& v, at::Tensor& out,
346- // const at::Tensor& b_loc,
347- // const at::Tensor& b_start_loc,
348- // const at::Tensor& b_seq_len,
349- // int max_input_len, int other_kv_index) {
350- // callDiopi(diopiTokenDecodeAttentionInference, out, q, k, v, b_loc, b_start_loc,
351- // b_seq_len, max_input_len, other_kv_index);
352- // }
353-
354- // void extTokenDecodeAttentionInferenceBatchOne(const at::Tensor& q, const at::Tensor& k,
355- // const at::Tensor& v, at::Tensor& out,
356- // const at::Tensor& b_loc,
357- // const at::Tensor& b_start_loc,
358- // const at::Tensor& b_seq_len,
359- // int max_input_len, int other_kv_index) {
360- // callDiopi(diopiTokenDecodeAttentionInferenceBatchOne, out, q, k, v, b_loc, b_start_loc,
361- // b_seq_len, max_input_len, other_kv_index);
362- // }
363-
364- // void extIncreFlashAttention(const at::Tensor& q, const at::Tensor& k,
365- // const at::Tensor& v, at::Tensor& out,
366- // const int head, const char* layout,
367- // const c10::optional<at::Tensor>& padding_mask = {},
368- // const c10::optional<at::Tensor>& atten_mask = {},
369- // const OptionalIntArray& actual_seq_lengths = {},
370- // int64_t num_heads = 1, double scale_value = 1.0,
371- // const std::string& input_layout = "BSH", int64_t num_key_value_heads = 0) {
372- // callDiopi(diopiIncreFlashAttention, out, q, k, v, padding_mask, atten_mask,
373- // actual_seq_lengths, num_heads, scale_value, input_layout.c_str(), num_key_value_heads);
374- // }
375-
376344void extPromptFlashAttention (at::Tensor& out, const at::Tensor& q,
377345 const at::Tensor& k, const at::Tensor& v,
378346 const at::Tensor& atten_mask,
379347 const at::IntArrayRef& actual_seq_lengths,
380- int64_t max_input_len, int64_t num_heads,
348+ int64_t max_input_len, int64_t num_heads,
381349 int64_t num_key_value_heads, int64_t dim) {
382350 callDiopi (diopiPromptFlashAttention, out, q, k, v, atten_mask,
383- actual_seq_lengths, max_input_len, num_heads, num_key_value_heads, dim);
351+ actual_seq_lengths, max_input_len, num_heads, num_key_value_heads,
352+ dim);
384353}
385354
386355void extContextAttentionInference (const at::Tensor& q, const at::Tensor& k,
@@ -403,34 +372,39 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
403372}
404373
405374void extApplyPenaltyV2 (at::Tensor& logits, const at::Tensor& presence_penalty,
406- const at::Tensor& frequency_penalty,
407- const at::Tensor& repetition_penalty,
408- const at::Tensor& p_token_ids,
409- const at::Tensor& p_token_counts) {
410- callDiopi (diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty, repetition_penalty,
411- p_token_ids, p_token_counts);
375+ const at::Tensor& frequency_penalty,
376+ const at::Tensor& repetition_penalty,
377+ const at::Tensor& p_token_ids,
378+ const at::Tensor& p_token_counts) {
379+ callDiopi (diopiApplyPenaltyV2, logits, presence_penalty, frequency_penalty,
380+ repetition_penalty, p_token_ids, p_token_counts);
412381}
413382
414- void extPagedAttention (at::Tensor& out, const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
415- const at::IntArrayRef& actual_seq_lengths,
416- int64_t numHeads, int64_t numKeyValueHeads, int64_t dim,
417- const at::Tensor& block_table,
418- int64_t block_size) {
419- callDiopi (diopiPagedAttention, out, q, k, v, actual_seq_lengths,
420- numHeads, numKeyValueHeads, dim,
421- block_table, block_size);
383+ void extPagedAttention (at::Tensor& out, const at::Tensor& q,
384+ const at::Tensor& k, const at::Tensor& v,
385+ const c10::optional<at::Tensor>& atten_mask = {},
386+ const at::IntArrayRef& actual_seq_lengths = {},
387+ int64_t num_heads = 1 , int64_t num_kv_heads = 1 ,
388+ int64_t dim = 1 ,
389+ const c10::optional<at::Tensor>& block_table = {},
390+ int64_t block_size = 1 ) {
391+ callDiopi (diopiPagedAttention, out, q, k, v, atten_mask, actual_seq_lengths,
392+ num_heads, num_kv_heads, dim, block_table, block_size);
422393}
423394
424- void extRotaryEmbeddingV2 (at::Tensor& query, at::Tensor& key, const at::Tensor& cos, const at::Tensor& sin, int64_t dim) {
395+ void extRotaryEmbeddingV2 (at::Tensor& query, at::Tensor& key,
396+ const at::Tensor& cos, const at::Tensor& sin,
397+ int64_t dim) {
425398 callDiopi (diopiRotaryEmbeddingV2, query, key, cos, sin, dim);
426399}
427400
428401void extMatmulAllReduce (at::Tensor& out, const at::Tensor& x1,
429- const at::Tensor& x2, const c10::optional<at::Tensor>& bias,
402+ const at::Tensor& x2,
403+ const c10::optional<at::Tensor>& bias,
430404 const char * group, const char * reduce_op,
431405 int64_t comm_turn, int64_t stream_mode) {
432- callDiopi (diopiMatmulAllReduce, out, x1, x2,
433- bias, group, reduce_op, comm_turn, stream_mode);
406+ callDiopi (diopiMatmulAllReduce, out, x1, x2, bias, group, reduce_op,
407+ comm_turn, stream_mode);
434408}
435409
436410// 判断是否有对应的 diopi 实现:
@@ -501,18 +475,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
501475 m.def (" token_softmax_reducev_inference" , &extTokenSoftmaxReduceVInference,
502476 " deeplink ext_token_softmax_reducev_inference" );
503477 }
504- // if (&diopiTokenDecodeAttentionInference != nullptr) {
505- // m.def("token_decode_attention_inference", &extTokenDecodeAttentionInference,
506- // "deeplink token_decode_attention_inference");
507- // }
508- // if (&diopiTokenDecodeAttentionInferenceBatchOne != nullptr) {
509- // m.def("token_decode_attention_inference_batch_one", &extTokenDecodeAttentionInferenceBatchOne,
510- // "deeplink token_decode_attention_inference");
511- // }
512- // if (&diopiIncreFlashAttention != nullptr) {
513- // m.def("incre_flash_attention", &extIncreFlashAttention,
514- // "deeplink incre_flash_attention");
515- // }
516478 if (&diopiPromptFlashAttention != nullptr ) {
517479 m.def (" prompt_flash_attention" , &extPromptFlashAttention,
518480 " deeplink ext_prompt_flash_attention" );
@@ -540,15 +502,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
540502 " deeplink ext_paged_attention" );
541503 }
542504 if (&diopiRotaryEmbeddingV2 != nullptr ) {
543- m.def (" rotary_embedding_v2" , &extRotaryEmbeddingV2, " deeplink extRotaryEmbeddingV2" );
505+ m.def (" rotary_embedding_v2" , &extRotaryEmbeddingV2,
506+ " deeplink extRotaryEmbeddingV2" );
544507 }
545508 if (&diopiMatmulAllReduce != nullptr ) {
546509 m.def (" matmul_all_reduce" , &extMatmulAllReduce,
547- " deeplink ext_matmul_all_reduce" ,
548- py::arg (" out" ), py::arg (" x1" ),
549- py::arg (" x2" ), py::arg (" bias" ),
550- py::arg (" group" ), py::arg (" reduce_op" ) = " sum" ,
551- py::arg (" comm_turn" ) = 0 , py::arg (" stream_mode" ) = 1 );
510+ " deeplink ext_matmul_all_reduce" , py::arg (" out" ), py::arg (" x1" ),
511+ py::arg (" x2" ), py::arg (" bias" ), py::arg (" group" ),
512+ py::arg (" reduce_op" ) = " sum" , py::arg (" comm_turn" ) = 0 ,
513+ py::arg (" stream_mode" ) = 1 );
552514 }
553515}
554516
0 commit comments