Skip to content

Commit d5685d4

Browse files
author
LittleMouse
committed
[update] kws support axmodel
1 parent ea7ddd0 commit d5685d4

File tree

7 files changed

+1225
-44
lines changed

7 files changed

+1225
-44
lines changed

projects/llm_framework/main_kws/SConstruct

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ Import('env')
44
with open(env['PROJECT_TOOL_S']) as f:
55
exec(f.read())
66

7-
SRCS = Glob('src/*.c*')
7+
SRCS = append_srcs_dir(ADir('src'))
88
INCLUDE = [ADir('include'), ADir('.')]
99
PRIVATE_INCLUDE = []
10-
REQUIREMENTS = ['pthread', 'dl', 'utilities', 'eventpp', 'StackFlow', 'single_header_libs']
10+
REQUIREMENTS = ['pthread', 'dl', 'utilities', 'ax_msp', 'eventpp', 'StackFlow', 'single_header_libs']
1111
STATIC_LIB = []
1212
DYNAMIC_LIB = []
1313
DEFINITIONS = []
@@ -22,14 +22,16 @@ DEFINITIONS += ['-std=c++17']
2222
LDFLAGS+=['-Wl,-rpath=/opt/m5stack/lib', '-Wl,-rpath=/usr/local/m5stack/lib', '-Wl,-rpath=/usr/local/m5stack/lib/gcc-10.3', '-Wl,-rpath=/opt/lib', '-Wl,-rpath=/opt/usr/lib', '-Wl,-rpath=./']
2323
LINK_SEARCH_PATH += [ADir('../static_lib')]
2424

25-
INCLUDE += [ADir('../static_lib/include/sherpa'),
25+
INCLUDE += [ADir('src/runner'),
26+
ADir('../static_lib/include/sherpa'),
2627
ADir('../static_lib/include/sherpa/fbank'),
2728
ADir('../static_lib/include/sherpa/sherpa-onnx'),
2829
ADir('../static_lib/include/sherpa/sherpa-onnx/onnxruntime-src'),
2930
ADir('../static_lib/include/sherpa/sherpa-onnx/openfst-src')
3031
]
3132

3233
LINK_SEARCH_PATH += [ADir('../static_lib/sherpa/onnx')]
34+
REQUIREMENTS += ['ax_engine', 'ax_interpreter', 'ax_sys']
3335
REQUIREMENTS += ['onnxruntime']
3436

3537
LDFLAGS += ['-l:libsherpa-onnx-core.a', '-l:libkaldi-native-fbank-core.a','-l:libkissfft-float.a',

projects/llm_framework/main_kws/src/main.cpp

Lines changed: 88 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
*/
66
#include "StackFlow.h"
77

8+
#include "EngineWrapper.hpp"
9+
#include "ax_sys_api.h"
10+
811
#include <signal.h>
912
#include <sys/stat.h>
1013
#include <sys/types.h>
@@ -34,9 +37,6 @@ static std::string base_model_config_path_;
3437
typedef std::function<void(const std::string &data, bool finish)> task_callback_t;
3538

3639
#include "sherpa-onnx/csrc/keyword-spotter.h"
37-
#include "sherpa-onnx/csrc/parse-options.h"
38-
39-
#include <onnxruntime_cxx_api.h>
4040
#include "kaldi-native-fbank/csrc/online-feature.h"
4141

4242
typedef struct mode_config_onnx {
@@ -61,6 +61,7 @@ class llm_task {
6161
bool enstream_ = false;
6262
bool enwake_audio_ = true;
6363
std::atomic_bool audio_flage_;
64+
static int ax_init_flage_;
6465
task_callback_t out_callback_;
6566
buffer_t *pcmdata;
6667
std::string wake_wav_file_;
@@ -72,8 +73,8 @@ class llm_task {
7273
std::unique_ptr<sherpa_onnx::OnlineStream> sherpa_stream_;
7374

7475
kws_config_onnx onnx_config_;
75-
std::vector<float> onnx_cache_;
76-
std::unique_ptr<Ort::Session> onnx_session_;
76+
std::vector<float> axera_cache_;
77+
std::unique_ptr<EngineWrapper> axera_session_;
7778
knf::FbankOptions fbank_opts_;
7879
std::unique_ptr<knf::OnlineFbank> fbank_;
7980
Ort::Env onnx_env_{ORT_LOGGING_LEVEL_WARNING, "kws"};
@@ -328,16 +329,20 @@ class llm_task {
328329
}
329330
std::string base_model = base_model_path_ + model_ + "/";
330331
SLOGI("base_model %s", base_model.c_str());
331-
std::string model_file = base_model + "kws.onnx";
332+
std::string model_file = base_model + "kws.axmodel";
332333

333334
if (config_body.contains("wake_wav_file"))
334335
wake_wav_file_ = config_body["wake_wav_file"];
335336
else if (file_body["mode_param"].contains("wake_wav_file"))
336337
wake_wav_file_ = file_body["mode_param"]["wake_wav_file"];
337338

338-
onnx_session_ = std::make_unique<Ort::Session>(onnx_env_, model_file.c_str(), session_options_);
339+
axera_session_ = std::make_unique<EngineWrapper>();
340+
if (0 != axera_session_->Init(model_file.c_str())) {
341+
SLOGE("Init axera model failed!");
342+
return -5;
343+
}
339344

340-
onnx_cache_.assign(1 * 32 * 88, 0.0f);
345+
axera_cache_.assign(1 * 32 * 88, 0.0f);
341346

342347
auto &mp = file_body["mode_param"];
343348
CONFIG_AUTO_SET_ONNX(mp, chunk_size);
@@ -414,37 +419,46 @@ class llm_task {
414419

415420
std::vector<float> run_inference(const std::vector<float> &audio_chunk_16k)
416421
{
417-
std::vector<std::vector<float>> fbank_feats;
418-
fbank_feats = compute_fbank_kaldi(audio_chunk_16k, onnx_config_.RESAMPLE_RATE, onnx_config_.FEAT_DIM);
419-
if (fbank_feats.empty()) {
422+
std::vector<std::vector<float>> fbank_feats =
423+
compute_fbank_kaldi(audio_chunk_16k, onnx_config_.RESAMPLE_RATE, onnx_config_.FEAT_DIM);
424+
if (fbank_feats.empty()) return {};
425+
426+
constexpr int FIX_T = 32;
427+
const int FEAT_DIM = onnx_config_.FEAT_DIM;
428+
429+
std::vector<float> mat_flattened;
430+
mat_flattened.resize(FIX_T * FEAT_DIM, 0.0f);
431+
432+
const int T_in = static_cast<int>(fbank_feats.size());
433+
const int T_copy = std::min(T_in, FIX_T);
434+
435+
for (int t = 0; t < T_copy; ++t) {
436+
if ((int)fbank_feats[t].size() < FEAT_DIM) continue;
437+
std::memcpy(mat_flattened.data() + t * FEAT_DIM, fbank_feats[t].data(), sizeof(float) * FEAT_DIM);
438+
}
439+
440+
axera_session_->SetInput(mat_flattened.data(), 0);
441+
442+
axera_session_->SetInput(axera_cache_.data(), 1);
443+
444+
int ret = axera_session_->Run();
445+
if (ret) {
446+
SLOGE("axera_session run failed!");
420447
return {};
421448
}
422-
int T = fbank_feats.size();
423-
std::vector<float> mat_flattened;
424-
for (const auto &feat : fbank_feats) {
425-
mat_flattened.insert(mat_flattened.end(), feat.begin(), feat.end());
426-
}
427-
std::vector<int64_t> input_shape = {1, static_cast<int64_t>(T), onnx_config_.FEAT_DIM};
428-
std::vector<int64_t> cache_shape = {1, 32, 88};
429-
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
430-
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
431-
memory_info, mat_flattened.data(), mat_flattened.size(), input_shape.data(), input_shape.size());
432-
Ort::Value cache_tensor = Ort::Value::CreateTensor<float>(memory_info, onnx_cache_.data(), onnx_cache_.size(),
433-
cache_shape.data(), cache_shape.size());
434-
const char *input_names[] = {"input", "cache"};
435-
const char *output_names[] = {"output", "r_cache"};
436-
std::vector<Ort::Value> inputs;
437-
inputs.push_back(std::move(input_tensor));
438-
inputs.push_back(std::move(cache_tensor));
439-
auto output_tensors =
440-
onnx_session_->Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), 2, output_names, 2);
441-
float *out_data = output_tensors[0].GetTensorMutableData<float>();
442-
float *cache_out_data = output_tensors[1].GetTensorMutableData<float>();
443-
std::vector<int64_t> out_shape = output_tensors[0].GetTensorTypeAndShapeInfo().GetShape();
444-
size_t out_size = 1;
445-
for (auto dim : out_shape) out_size *= dim;
446-
std::vector<float> out_chunk(out_data, out_data + out_size);
447-
std::copy(cache_out_data, cache_out_data + onnx_cache_.size(), onnx_cache_.begin());
449+
450+
const float *out_ptr = reinterpret_cast<const float *>(axera_session_->GetOutputPtr(0));
451+
size_t out_size_f = axera_session_->GetOutputSize(0) / sizeof(float);
452+
std::vector<float> out_chunk(out_ptr, out_ptr + out_size_f);
453+
454+
const float *cache_ptr = reinterpret_cast<const float *>(axera_session_->GetOutputPtr(1));
455+
size_t cache_size_f = axera_session_->GetOutputSize(1) / sizeof(float);
456+
if (cache_size_f != axera_cache_.size()) {
457+
SLOGE("cache size mismatch: out=%zu, local=%zu", cache_size_f, axera_cache_.size());
458+
return out_chunk;
459+
}
460+
std::memcpy(axera_cache_.data(), cache_ptr, axera_cache_.size() * sizeof(float));
461+
448462
return out_chunk;
449463
}
450464

@@ -531,16 +545,45 @@ class llm_task {
531545

532546
bool delete_model()
533547
{
534-
sherpa_spotter_.reset();
535-
sherpa_stream_.reset();
536-
onnx_session_.reset();
537-
fbank_.reset();
548+
if (sherpa_spotter_) sherpa_spotter_.reset();
549+
if (sherpa_stream_) sherpa_stream_.reset();
550+
if (axera_session_) axera_session_->Release();
551+
if (fbank_) fbank_.reset();
538552
return true;
539553
}
540554

541555
llm_task(const std::string &workid) : audio_flage_(false)
542556
{
543557
pcmdata = buffer_create();
558+
_ax_init();
559+
}
560+
561+
void _ax_init()
562+
{
563+
if (!ax_init_flage_) {
564+
int ret = AX_SYS_Init();
565+
if (0 != ret) {
566+
fprintf(stderr, "AX_SYS_Init failed! ret = 0x%x\n", ret);
567+
}
568+
AX_ENGINE_NPU_ATTR_T npu_attr;
569+
memset(&npu_attr, 0, sizeof(npu_attr));
570+
ret = AX_ENGINE_Init(&npu_attr);
571+
if (0 != ret) {
572+
fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret);
573+
}
574+
}
575+
ax_init_flage_++;
576+
}
577+
578+
void _ax_deinit()
579+
{
580+
if (ax_init_flage_ > 0) {
581+
--ax_init_flage_;
582+
if (!ax_init_flage_) {
583+
AX_ENGINE_Deinit();
584+
AX_SYS_Deinit();
585+
}
586+
}
544587
}
545588

546589
void start()
@@ -555,9 +598,13 @@ class llm_task {
555598
{
556599
stop();
557600
buffer_destroy(pcmdata);
601+
if (axera_session_) axera_session_->Release();
602+
_ax_deinit();
558603
}
559604
};
560605

606+
int llm_task::ax_init_flage_ = 0;
607+
561608
class llm_kws : public StackFlow {
562609
private:
563610
enum { EVENT_TRIGGER = EVENT_EXPORT + 1 };

0 commit comments

Comments
 (0)