Cortex-M backend: Add AoT scratch-buffer planning. #19636
base: main
Changes from all commits
fae7030
936ed4b
699e731
23c86c1
264fc9e
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * Copyright 2026 Arm Limited and/or its affiliates. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
|
|
@@ -71,6 +72,7 @@ Tensor& quantized_batch_matmul_out( | |
| int64_t output_offset, | ||
| int64_t output_multiplier, | ||
| int64_t output_shift, | ||
| const Tensor& scratch, | ||
| Tensor& out) { | ||
| if (!validate_batch_matmul_arguments(context, lhs, rhs_transposed, out)) { | ||
| return out; | ||
|
|
@@ -100,25 +102,26 @@ Tensor& quantized_batch_matmul_out( | |
| quant_params.multiplier = static_cast<int32_t>(output_multiplier); | ||
| quant_params.shift = static_cast<int32_t>(output_shift); | ||
|
|
||
| const int32_t buf_size = arm_fully_connected_s8_get_buffer_size(&out_dims); | ||
|
|
||
| cmsis_nn_context ctx; | ||
| ctx.buf = nullptr; | ||
| ctx.size = 0; | ||
|
|
||
| if (buf_size > 0) { | ||
| auto buffer_or_error = context.allocate_temp(buf_size); | ||
| if (!buffer_or_error.ok()) { | ||
| ET_LOG( | ||
| Error, | ||
| "quantized_batch_matmul: failed to allocate scratch buffer (%d bytes)", | ||
| buf_size); | ||
| context.fail(buffer_or_error.error()); | ||
| return out; | ||
| } | ||
| ctx.buf = buffer_or_error.get(); | ||
| ctx.size = buf_size; | ||
| ctx.size = scratch.nbytes(); | ||
| if (ctx.size > 0) { | ||
| ctx.buf = scratch.mutable_data_ptr<int8_t>(); | ||
| } | ||
|
|
||
| #ifdef CORTEX_M_ENABLE_ASSERTS | ||
| const int32_t runtime_buffer_bytes = | ||
| arm_fully_connected_s8_get_buffer_size(&out_dims); | ||
| if (ctx.size != static_cast<size_t>(runtime_buffer_bytes)) { | ||
|
Contributor
WDYT about doing
Collaborator
Author
To me, this is a correctness assert. We don't just want to avoid failure; we want to ensure correctness.
||
| ET_LOG( | ||
| Error, | ||
| "quantized_batch_matmul: scratch buffer size incorrect - actual: (%d) needed: (%d)", | ||
| static_cast<int>(ctx.size), | ||
| runtime_buffer_bytes); | ||
| context.fail(Error::Internal); | ||
| return out; | ||
| } | ||
| #endif | ||
|
|
||
| const arm_cmsis_nn_status status = arm_batch_matmul_s8( | ||
| &ctx, | ||
|
|
||
As much as I want to limit the runtime checks, I think it'd be good to have this check always on and non-optional. Without it, we could wind up writing past the end of buffers.
Also, naming nit: technically this is not an assert, as it should not crash the program if it fails. Maybe ENABLE_RUNTIME_CHECKS?
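To make the two concerns in this thread concrete (guarding against an overrun versus enforcing an exact size match), here is a minimal sketch. `CheckResult` and `check_scratch` are hypothetical names for illustration only; they are not part of this PR, ExecuTorch, or CMSIS-NN.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical result type for illustration; not the ExecuTorch Error enum.
enum class CheckResult { Ok, TooSmall, Mismatch };

// Compares the AoT-planned scratch size with the size the kernel library
// reports at runtime. A negative `required` is treated as a mismatch.
inline CheckResult check_scratch(std::size_t planned, std::int32_t required) {
  if (required < 0) {
    return CheckResult::Mismatch;
  }
  const std::size_t need = static_cast<std::size_t>(required);
  if (planned < need) {
    return CheckResult::TooSmall;  // the kernel could write past the buffer
  }
  if (planned != need) {
    return CheckResult::Mismatch;  // memory-safe, but planning disagrees
  }
  return CheckResult::Ok;
}
```

Only `TooSmall` is a memory-safety problem; `Mismatch` signals that the ahead-of-time planner and the runtime library disagree, which is the case the `!=` comparison in the diff also rejects.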
The idea is that, after testing, we should be confident that we are doing the correct allocation. Users can turn this on to verify, for example, that they have not mixed up cmsis_nn versions, and then skip it in production. That's also why I want it to be a crash: if there is a mismatch here, I want to enforce a fix. Also, once this flag is available, we can use it in more places.
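A rough sketch of the opt-in, crash-on-mismatch behaviour described above, assuming a compile-time flag. `SKETCH_ENABLE_ASSERTS`, `kScratchChecksEnabled`, and `verify_scratch` are hypothetical stand-ins and are not taken from the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// SKETCH_ENABLE_ASSERTS is a hypothetical flag standing in for the
// CORTEX_M_ENABLE_ASSERTS flag that appears in the diff.
#ifdef SKETCH_ENABLE_ASSERTS
constexpr bool kScratchChecksEnabled = true;
#else
constexpr bool kScratchChecksEnabled = false;
#endif

// In checked builds, aborts when the planned scratch size differs from the
// size required at runtime. In unchecked builds, returns immediately.
inline void verify_scratch(std::size_t planned, std::int32_t required) {
  if (!kScratchChecksEnabled) {
    return;
  }
  if (required < 0 || planned != static_cast<std::size_t>(required)) {
    std::fprintf(
        stderr,
        "scratch buffer size incorrect - actual: (%zu) needed: (%d)\n",
        planned,
        static_cast<int>(required));
    std::abort();  // hard failure, so a mismatch cannot go unnoticed
  }
}
```

With the flag undefined, the check costs nothing at runtime, which matches the "verify during testing, skip in production" intent; the trade-off is that a production build has no protection if the planned size is wrong.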