55 */
66#include " StackFlow.h"
77
8+ #include " EngineWrapper.hpp"
9+ #include " ax_sys_api.h"
10+
811#include < signal.h>
912#include < sys/stat.h>
1013#include < sys/types.h>
@@ -34,9 +37,6 @@ static std::string base_model_config_path_;
3437typedef std::function<void (const std::string &data, bool finish)> task_callback_t ;
3538
3639#include " sherpa-onnx/csrc/keyword-spotter.h"
37- #include " sherpa-onnx/csrc/parse-options.h"
38-
39- #include < onnxruntime_cxx_api.h>
4040#include " kaldi-native-fbank/csrc/online-feature.h"
4141
4242typedef struct mode_config_onnx {
@@ -61,6 +61,7 @@ class llm_task {
6161 bool enstream_ = false ;
6262 bool enwake_audio_ = true ;
6363 std::atomic_bool audio_flage_;
64+ static int ax_init_flage_;
6465 task_callback_t out_callback_;
6566 buffer_t *pcmdata;
6667 std::string wake_wav_file_;
@@ -72,8 +73,8 @@ class llm_task {
7273 std::unique_ptr<sherpa_onnx::OnlineStream> sherpa_stream_;
7374
7475 kws_config_onnx onnx_config_;
75- std::vector<float > onnx_cache_ ;
76- std::unique_ptr<Ort::Session> onnx_session_ ;
76+ std::vector<float > axera_cache_ ;
77+ std::unique_ptr<EngineWrapper> axera_session_ ;
7778 knf::FbankOptions fbank_opts_;
7879 std::unique_ptr<knf::OnlineFbank> fbank_;
7980 Ort::Env onnx_env_{ORT_LOGGING_LEVEL_WARNING, " kws" };
@@ -328,16 +329,20 @@ class llm_task {
328329 }
329330 std::string base_model = base_model_path_ + model_ + " /" ;
330331 SLOGI (" base_model %s" , base_model.c_str ());
331- std::string model_file = base_model + " kws.onnx " ;
332+ std::string model_file = base_model + " kws.axmodel " ;
332333
333334 if (config_body.contains (" wake_wav_file" ))
334335 wake_wav_file_ = config_body[" wake_wav_file" ];
335336 else if (file_body[" mode_param" ].contains (" wake_wav_file" ))
336337 wake_wav_file_ = file_body[" mode_param" ][" wake_wav_file" ];
337338
338- onnx_session_ = std::make_unique<Ort::Session>(onnx_env_, model_file.c_str (), session_options_);
339+ axera_session_ = std::make_unique<EngineWrapper>();
340+ if (0 != axera_session_->Init (model_file.c_str ())) {
341+ SLOGE (" Init axera model failed!" );
342+ return -5 ;
343+ }
339344
340- onnx_cache_ .assign (1 * 32 * 88 , 0 .0f );
345+ axera_cache_ .assign (1 * 32 * 88 , 0 .0f );
341346
342347 auto &mp = file_body[" mode_param" ];
343348 CONFIG_AUTO_SET_ONNX (mp, chunk_size);
@@ -414,37 +419,46 @@ class llm_task {
414419
415420 std::vector<float > run_inference (const std::vector<float > &audio_chunk_16k)
416421 {
417- std::vector<std::vector<float >> fbank_feats;
418- fbank_feats = compute_fbank_kaldi (audio_chunk_16k, onnx_config_.RESAMPLE_RATE , onnx_config_.FEAT_DIM );
419- if (fbank_feats.empty ()) {
422+ std::vector<std::vector<float >> fbank_feats =
423+ compute_fbank_kaldi (audio_chunk_16k, onnx_config_.RESAMPLE_RATE , onnx_config_.FEAT_DIM );
424+ if (fbank_feats.empty ()) return {};
425+
426+ constexpr int FIX_T = 32 ;
427+ const int FEAT_DIM = onnx_config_.FEAT_DIM ;
428+
429+ std::vector<float > mat_flattened;
430+ mat_flattened.resize (FIX_T * FEAT_DIM, 0 .0f );
431+
432+ const int T_in = static_cast <int >(fbank_feats.size ());
433+ const int T_copy = std::min (T_in, FIX_T);
434+
435+ for (int t = 0 ; t < T_copy; ++t) {
436+ if ((int )fbank_feats[t].size () < FEAT_DIM) continue ;
437+ std::memcpy (mat_flattened.data () + t * FEAT_DIM, fbank_feats[t].data (), sizeof (float ) * FEAT_DIM);
438+ }
439+
440+ axera_session_->SetInput (mat_flattened.data (), 0 );
441+
442+ axera_session_->SetInput (axera_cache_.data (), 1 );
443+
444+ int ret = axera_session_->Run ();
445+ if (ret) {
446+ SLOGE (" axera_session run failed!" );
420447 return {};
421448 }
422- int T = fbank_feats.size ();
423- std::vector<float > mat_flattened;
424- for (const auto &feat : fbank_feats) {
425- mat_flattened.insert (mat_flattened.end (), feat.begin (), feat.end ());
426- }
427- std::vector<int64_t > input_shape = {1 , static_cast <int64_t >(T), onnx_config_.FEAT_DIM };
428- std::vector<int64_t > cache_shape = {1 , 32 , 88 };
429- Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeDefault);
430- Ort::Value input_tensor = Ort::Value::CreateTensor<float >(
431- memory_info, mat_flattened.data (), mat_flattened.size (), input_shape.data (), input_shape.size ());
432- Ort::Value cache_tensor = Ort::Value::CreateTensor<float >(memory_info, onnx_cache_.data (), onnx_cache_.size (),
433- cache_shape.data (), cache_shape.size ());
434- const char *input_names[] = {" input" , " cache" };
435- const char *output_names[] = {" output" , " r_cache" };
436- std::vector<Ort::Value> inputs;
437- inputs.push_back (std::move (input_tensor));
438- inputs.push_back (std::move (cache_tensor));
439- auto output_tensors =
440- onnx_session_->Run (Ort::RunOptions{nullptr }, input_names, inputs.data (), 2 , output_names, 2 );
441- float *out_data = output_tensors[0 ].GetTensorMutableData <float >();
442- float *cache_out_data = output_tensors[1 ].GetTensorMutableData <float >();
443- std::vector<int64_t > out_shape = output_tensors[0 ].GetTensorTypeAndShapeInfo ().GetShape ();
444- size_t out_size = 1 ;
445- for (auto dim : out_shape) out_size *= dim;
446- std::vector<float > out_chunk (out_data, out_data + out_size);
447- std::copy (cache_out_data, cache_out_data + onnx_cache_.size (), onnx_cache_.begin ());
449+
450+ const float *out_ptr = reinterpret_cast <const float *>(axera_session_->GetOutputPtr (0 ));
451+ size_t out_size_f = axera_session_->GetOutputSize (0 ) / sizeof (float );
452+ std::vector<float > out_chunk (out_ptr, out_ptr + out_size_f);
453+
454+ const float *cache_ptr = reinterpret_cast <const float *>(axera_session_->GetOutputPtr (1 ));
455+ size_t cache_size_f = axera_session_->GetOutputSize (1 ) / sizeof (float );
456+ if (cache_size_f != axera_cache_.size ()) {
457+ SLOGE (" cache size mismatch: out=%zu, local=%zu" , cache_size_f, axera_cache_.size ());
458+ return out_chunk;
459+ }
460+ std::memcpy (axera_cache_.data (), cache_ptr, axera_cache_.size () * sizeof (float ));
461+
448462 return out_chunk;
449463 }
450464
@@ -531,16 +545,45 @@ class llm_task {
531545
532546 bool delete_model ()
533547 {
534- sherpa_spotter_.reset ();
535- sherpa_stream_.reset ();
536- onnx_session_. reset ();
537- fbank_.reset ();
548+ if (sherpa_spotter_) sherpa_spotter_.reset ();
549+ if (sherpa_stream_) sherpa_stream_.reset ();
550+ if (axera_session_) axera_session_-> Release ();
551+ if (fbank_) fbank_.reset ();
538552 return true ;
539553 }
540554
541555 llm_task (const std::string &workid) : audio_flage_(false )
542556 {
543557 pcmdata = buffer_create ();
558+ _ax_init ();
559+ }
560+
561+ void _ax_init ()
562+ {
563+ if (!ax_init_flage_) {
564+ int ret = AX_SYS_Init ();
565+ if (0 != ret) {
566+ fprintf (stderr, " AX_SYS_Init failed! ret = 0x%x\n " , ret);
567+ }
568+ AX_ENGINE_NPU_ATTR_T npu_attr;
569+ memset (&npu_attr, 0 , sizeof (npu_attr));
570+ ret = AX_ENGINE_Init (&npu_attr);
571+ if (0 != ret) {
572+ fprintf (stderr, " Init ax-engine failed{0x%8x}.\n " , ret);
573+ }
574+ }
575+ ax_init_flage_++;
576+ }
577+
578+ void _ax_deinit ()
579+ {
580+ if (ax_init_flage_ > 0 ) {
581+ --ax_init_flage_;
582+ if (!ax_init_flage_) {
583+ AX_ENGINE_Deinit ();
584+ AX_SYS_Deinit ();
585+ }
586+ }
544587 }
545588
546589 void start ()
@@ -555,9 +598,13 @@ class llm_task {
555598 {
556599 stop ();
557600 buffer_destroy (pcmdata);
601+ if (axera_session_) axera_session_->Release ();
602+ _ax_deinit ();
558603 }
559604};
560605
// Out-of-class definition of the static counter declared in llm_task.
// _ax_init() increments it and _ax_deinit() decrements it; the AX engine and
// system are initialized while it is zero and deinitialized when it drops
// back to zero.
606+ int llm_task::ax_init_flage_ = 0 ;
607+
561608class llm_kws : public StackFlow {
562609private:
563610 enum { EVENT_TRIGGER = EVENT_EXPORT + 1 };
0 commit comments