From 9eef6dc150ddae3f7f376865af0691805cd1e031 Mon Sep 17 00:00:00 2001 From: "Yuxuan (William) Liu" Date: Tue, 19 May 2026 17:44:24 -0700 Subject: [PATCH] Thread kernel_registry through Module::load_method (#19641) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Previous commit added a method-scoped kernel registry to Program::load_method and Method, allowing callers to override specific kernels for a single method without affecting the global registry. However, the Module facade class did not expose this parameter, forcing consumers to bypass Module and manage memory manually. This adds an optional `Span kernel_registry` parameter (defaulting to empty) to Module::load_method and Module::load_forward, and forwards it to Program::load_method. Existing callers are completely unaffected — the default empty span causes the runtime to fall back to the global kernel registry, exactly as before. Reviewed By: JacobSzwejbka Differential Revision: D104433196 --- extension/module/module.cpp | 10 ++++++++-- extension/module/module.h | 14 +++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 459b062fd67..5422fb15b71 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -20,6 +20,7 @@ namespace extension { namespace ET_MODULE_NAMESPACE { using ET_MERGED_DATA_MAP_NAMESPACE::MergedDataMap; +using ET_RUNTIME_NAMESPACE::Kernel; using ET_RUNTIME_NAMESPACE::MethodMeta; using ET_RUNTIME_NAMESPACE::Program; @@ -406,7 +407,8 @@ runtime::Error Module::load_method( const std::string& method_name, runtime::HierarchicalAllocator* planned_memory, torch::executor::EventTracer* event_tracer, - const LoadBackendOptionsMap* backend_options) { + const LoadBackendOptionsMap* backend_options, + std::vector kernel_registry) { if (!is_method_loaded(method_name)) { ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -446,12 +448,16 @@ runtime::Error Module::load_method( method_holder.memory_manager = std::make_unique( memory_allocator_.get(), planned_memory, temp_allocator_.get()); + method_holder.kernel_registry = std::move(kernel_registry); auto res_method = program_->load_method( method_name.c_str(), method_holder.memory_manager.get(), event_tracer ? event_tracer : this->event_tracer(), merged_data_map_.get(), - effective_backend_options); + effective_backend_options, + runtime::Span( + method_holder.kernel_registry.data(), + method_holder.kernel_registry.size())); if (!res_method.ok()) { return res_method.error(); } diff --git a/extension/module/module.h b/extension/module/module.h index 2536fbc01ec..47ead23032e 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -27,6 +27,7 @@ namespace executorch { namespace extension { +using ET_RUNTIME_NAMESPACE::Kernel; using ET_RUNTIME_NAMESPACE::Method; using ET_RUNTIME_NAMESPACE::MethodMeta; using ET_RUNTIME_NAMESPACE::NamedDataMap; @@ -281,7 +282,8 @@ class Module { const std::string& method_name, runtime::HierarchicalAllocator* planned_memory = nullptr, torch::executor::EventTracer* event_tracer = nullptr, - const LoadBackendOptionsMap* backend_options = nullptr); + const LoadBackendOptionsMap* backend_options = nullptr, + std::vector kernel_registry = {}); ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method( const std::string& method_name, @@ -329,9 +331,14 @@ class Module { ET_NODISCARD inline runtime::Error load_forward( runtime::HierarchicalAllocator* planned_memory = nullptr, torch::executor::EventTracer* event_tracer = nullptr, - const LoadBackendOptionsMap* backend_options = nullptr) { + const LoadBackendOptionsMap* backend_options = nullptr, + std::vector kernel_registry = {}) { return load_method( - "forward", planned_memory, event_tracer, backend_options); + "forward", + planned_memory, + event_tracer, + backend_options, + std::move(kernel_registry)); } ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward( @@ -724,6 +731,7 @@ class Module { std::unique_ptr planned_memory; std::unique_ptr memory_manager; std::unique_ptr method; + std::vector kernel_registry; }; std::string file_path_;