diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 459b062fd67..5422fb15b71 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -20,6 +20,7 @@ namespace extension { namespace ET_MODULE_NAMESPACE { using ET_MERGED_DATA_MAP_NAMESPACE::MergedDataMap; +using ET_RUNTIME_NAMESPACE::Kernel; using ET_RUNTIME_NAMESPACE::MethodMeta; using ET_RUNTIME_NAMESPACE::Program; @@ -406,7 +407,8 @@ runtime::Error Module::load_method( const std::string& method_name, runtime::HierarchicalAllocator* planned_memory, torch::executor::EventTracer* event_tracer, - const LoadBackendOptionsMap* backend_options) { + const LoadBackendOptionsMap* backend_options, + std::vector kernel_registry) { if (!is_method_loaded(method_name)) { ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -446,12 +448,16 @@ runtime::Error Module::load_method( method_holder.memory_manager = std::make_unique( memory_allocator_.get(), planned_memory, temp_allocator_.get()); + method_holder.kernel_registry = std::move(kernel_registry); auto res_method = program_->load_method( method_name.c_str(), method_holder.memory_manager.get(), event_tracer ? event_tracer : this->event_tracer(), merged_data_map_.get(), - effective_backend_options); + effective_backend_options, + runtime::Span( + method_holder.kernel_registry.data(), + method_holder.kernel_registry.size())); if (!res_method.ok()) { return res_method.error(); } diff --git a/extension/module/module.h b/extension/module/module.h index 2536fbc01ec..47ead23032e 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -27,6 +27,7 @@ namespace executorch { namespace extension { +using ET_RUNTIME_NAMESPACE::Kernel; using ET_RUNTIME_NAMESPACE::Method; using ET_RUNTIME_NAMESPACE::MethodMeta; using ET_RUNTIME_NAMESPACE::NamedDataMap; @@ -281,7 +282,8 @@ class Module { const std::string& method_name, runtime::HierarchicalAllocator* planned_memory = nullptr, torch::executor::EventTracer* event_tracer = nullptr, - const LoadBackendOptionsMap* backend_options = nullptr); + const LoadBackendOptionsMap* backend_options = nullptr, + std::vector kernel_registry = {}); ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method( const std::string& method_name, @@ -329,9 +331,14 @@ class Module { ET_NODISCARD inline runtime::Error load_forward( runtime::HierarchicalAllocator* planned_memory = nullptr, torch::executor::EventTracer* event_tracer = nullptr, - const LoadBackendOptionsMap* backend_options = nullptr) { + const LoadBackendOptionsMap* backend_options = nullptr, + std::vector kernel_registry = {}) { return load_method( - "forward", planned_memory, event_tracer, backend_options); + "forward", + planned_memory, + event_tracer, + backend_options, + std::move(kernel_registry)); } ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward( @@ -724,6 +731,7 @@ class Module { std::unique_ptr planned_memory; std::unique_ptr memory_manager; std::unique_ptr method; + std::vector kernel_registry; }; std::string file_path_;