Skip to content

Commit d8d7dc7

Browse files
YunanAZcopybara-github
authored andcommitted
Fix flaky HW cache update kernel by adding dequantization support and robust validation.
LiteRT-LM-PiperOrigin-RevId: 910140717
1 parent 4541df0 commit d8d7dc7

6 files changed

Lines changed: 391 additions & 42 deletions

runtime/executor/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,12 +543,14 @@ cc_library(
543543
"@com_google_absl//absl/status",
544544
"@com_google_absl//absl/status:statusor",
545545
"@com_google_absl//absl/strings",
546+
"@litert//litert/c:litert_layout",
546547
] + select({
547548
"@litert//litert:litert_link_capi_so": [
548549
"@litert//litert/cc:litert_api_with_dynamic_runtime",
549550
],
550551
"//conditions:default": [
551552
"@litert//litert/cc:litert_element_type",
553+
"@litert//litert/cc:litert_layout",
552554
"@litert//litert/cc:litert_macros",
553555
"@litert//litert/cc:litert_ranked_tensor_type",
554556
"@litert//litert/cc:litert_tensor_buffer",

runtime/executor/llm_litert_npu_compiled_model_executor.cc

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -845,9 +845,28 @@ absl::Status LlmLiteRtNpuCompiledModelExecutor::AllocateTransformerBuffers(
845845
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>&
846846
decode_output_kv_cache_slice_buffers,
847847
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>&
848-
verify_output_kv_cache_slice_buffers) {
848+
verify_output_kv_cache_slice_buffers,
849+
absl::flat_hash_map<absl::string_view, HWQuantParams>& kv_quant_params) {
849850
auto prefill_signature = transformer_model->FindSignature(kPrefillSignature);
850851

852+
if (prefill_signature.HasValue()) {
853+
for (auto output_name : prefill_signature->OutputNames()) {
854+
if (absl::StartsWith(output_name, kv_cache_slice_k_root_name) ||
855+
absl::StartsWith(output_name, kv_cache_slice_v_root_name)) {
856+
auto tensor_expected = prefill_signature->OutputTensor(output_name);
857+
if (tensor_expected.HasValue()) {
858+
HWQuantParams q_params;
859+
if (tensor_expected->HasQuantization()) {
860+
auto pq = tensor_expected->PerTensorQuantization();
861+
q_params.scale = pq.scale;
862+
q_params.zero_point = pq.zero_point;
863+
}
864+
kv_quant_params[output_name] = q_params;
865+
}
866+
}
867+
}
868+
}
869+
851870
// Create input buffers for prefill signature.
852871
for (auto input_name : prefill_signature->InputNames()) {
853872
if (absl::StartsWith(input_name, kv_cache_k_root_name) ||
@@ -1986,7 +2005,8 @@ absl::Status LlmLiteRtNpuCompiledModelExecutor::PrefillInternal(
19862005
if (prefill_kv_cache_update_method_ == KVCacheUpdateMethod::kWH) {
19872006
RETURN_IF_ERROR(HWKVCacheUpdate(
19882007
cache_update_inference_context_.prefill_input_buffers,
1989-
cache_update_inference_context_.prefill_output_buffers));
2008+
cache_update_inference_context_.prefill_output_buffers,
2009+
kv_quant_params_));
19902010
} else {
19912011
auto res = npu_auxiliary_context_.npu_auxiliary_compiled_model.Run(
19922012
CacheUpdateSignatures::kPrefillCacheUpdate,
@@ -2162,9 +2182,10 @@ absl::Status LlmLiteRtNpuCompiledModelExecutor::DecodeInternal(
21622182
{
21632183
auto start = absl::Now();
21642184
if (decode_kv_cache_update_method_ == KVCacheUpdateMethod::kWH) {
2165-
RETURN_IF_ERROR(HWKVCacheUpdate(
2166-
cache_update_inference_context_.decode_input_buffers,
2167-
cache_update_inference_context_.decode_output_buffers));
2185+
RETURN_IF_ERROR(
2186+
HWKVCacheUpdate(cache_update_inference_context_.decode_input_buffers,
2187+
cache_update_inference_context_.decode_output_buffers,
2188+
kv_quant_params_));
21682189
} else {
21692190
auto res = npu_auxiliary_context_.npu_auxiliary_compiled_model.Run(
21702191
CacheUpdateSignatures::kDecodeCacheUpdate,
@@ -2628,7 +2649,8 @@ absl::Status LlmLiteRtNpuCompiledModelExecutor::CommitVerifiedKVCache(
26282649
if (prefill_kv_cache_update_method_ == KVCacheUpdateMethod::kWH) {
26292650
RETURN_IF_ERROR(
26302651
HWKVCacheUpdate(cache_update_inference_context_.verify_input_buffers,
2631-
cache_update_inference_context_.verify_output_buffers));
2652+
cache_update_inference_context_.verify_output_buffers,
2653+
kv_quant_params_));
26322654
} else {
26332655
LITERT_RETURN_IF_ERROR(
26342656
npu_auxiliary_context_.npu_auxiliary_compiled_model.Run(
@@ -2742,12 +2764,13 @@ LlmLiteRtNpuCompiledModelExecutor::CreateForModelHasPerLayerEmbedding(
27422764
absl::flat_hash_map<absl::string_view, TensorBuffer>
27432765
verify_output_kv_cache_slice_buffers;
27442766

2767+
absl::flat_hash_map<absl::string_view, HWQuantParams> kv_quant_params;
27452768
RETURN_IF_ERROR(AllocateTransformerBuffers(
27462769
env, transformer_model, llm_compiled_model, gemma_prefill_input_buffers,
27472770
gemma_decode_input_buffers, gemma_verify_input_buffers,
27482771
input_kv_cache_buffers, prefill_output_kv_cache_slice_buffers,
27492772
decode_output_kv_cache_slice_buffers,
2750-
verify_output_kv_cache_slice_buffers));
2773+
verify_output_kv_cache_slice_buffers, kv_quant_params));
27512774

27522775
// Gemma3n specific fix: KV cache buffer 19 of *prefill* is not connected
27532776
// to any OPs in the model, making the LiteRT runtime allocate host memory
@@ -2978,8 +3001,8 @@ LlmLiteRtNpuCompiledModelExecutor::CreateForModelHasPerLayerEmbedding(
29783001
std::move(embedder_per_layer_context), quantization_params,
29793002
std::move(ple_table_ptrs), std::move(ple_quant_params),
29803003
std::move(ple_per_tensor_scales), table_count, output_type, final_scale,
2981-
final_zero_point, speculative_decoding_type, std::move(drafter_context),
2982-
std::move(drafter_aux_context)));
3004+
final_zero_point, std::move(kv_quant_params), speculative_decoding_type,
3005+
std::move(drafter_context), std::move(drafter_aux_context)));
29833006
return executor;
29843007
}
29853008

@@ -3013,12 +3036,13 @@ LlmLiteRtNpuCompiledModelExecutor::CreateForModelWithoutPerLayerEmbedding(
30133036
absl::flat_hash_map<absl::string_view, TensorBuffer>
30143037
verify_output_kv_cache_slice_buffers;
30153038

3039+
absl::flat_hash_map<absl::string_view, HWQuantParams> kv_quant_params;
30163040
RETURN_IF_ERROR(AllocateTransformerBuffers(
30173041
env, transformer_model, llm_compiled_model, gemma_prefill_input_buffers,
30183042
gemma_decode_input_buffers, gemma_verify_input_buffers,
30193043
input_kv_cache_buffers, prefill_output_kv_cache_slice_buffers,
30203044
decode_output_kv_cache_slice_buffers,
3021-
verify_output_kv_cache_slice_buffers));
3045+
verify_output_kv_cache_slice_buffers, kv_quant_params));
30223046
LITERT_ASSIGN_OR_RETURN(
30233047
auto llm_inference_context,
30243048
CreateLlmInferenceContextWithBufferSharing(
@@ -3182,8 +3206,9 @@ LlmLiteRtNpuCompiledModelExecutor::CreateForModelWithoutPerLayerEmbedding(
31823206
std::move(cache_update_inference_context), std::move(prefill_runner_set),
31833207
std::move(maybe_embedding_lookup_manager),
31843208
/*embedder_per_layer_context=*/std::nullopt, quantization_params, {}, {},
3185-
{}, 0, litert::ElementType::None, 1.0f, 0, speculative_decoding_type,
3186-
std::move(drafter_context), std::move(drafter_aux_context)));
3209+
{}, 0, litert::ElementType::None, 1.0f, 0, std::move(kv_quant_params),
3210+
speculative_decoding_type, std::move(drafter_context),
3211+
std::move(drafter_aux_context)));
31873212
return executor;
31883213
}
31893214

runtime/executor/llm_litert_npu_compiled_model_executor.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ class LlmLiteRtNpuCompiledModelExecutor : public LlmExecutor {
310310
std::vector<float> ple_per_tensor_scales = {}, int num_tables = 0,
311311
litert::ElementType output_type = litert::ElementType::None,
312312
float final_scale = 1.0f, int32_t final_zero_point = 0,
313+
absl::flat_hash_map<absl::string_view, HWQuantParams> kv_quant_params =
314+
{},
313315
SpeculativeDecodingType speculative_decoding_type =
314316
SpeculativeDecodingType::kNone,
315317
std::optional<DrafterContext> drafter_context = std::nullopt,
@@ -327,6 +329,7 @@ class LlmLiteRtNpuCompiledModelExecutor : public LlmExecutor {
327329
cache_update_inference_context_(
328330
std::move(cache_update_inference_context)),
329331
prefill_signature_map_(std::move(prefill_signature_map)),
332+
kv_quant_params_(std::move(kv_quant_params)),
330333
ple_table_ptrs_(std::move(ple_table_ptrs)),
331334
ple_quant_params_(std::move(ple_quant_params)),
332335
ple_per_tensor_scales_(std::move(ple_per_tensor_scales)),
@@ -575,7 +578,8 @@ class LlmLiteRtNpuCompiledModelExecutor : public LlmExecutor {
575578
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>&
576579
decode_output_kv_cache_slice_buffers,
577580
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>&
578-
verify_output_kv_cache_slice_buffers);
581+
verify_output_kv_cache_slice_buffers,
582+
absl::flat_hash_map<absl::string_view, HWQuantParams>& kv_quant_params);
579583

580584
// Create the executor for Gemma3n, with multi-modality support.
581585
static absl::StatusOr<std::unique_ptr<LlmLiteRtNpuCompiledModelExecutor>>
@@ -616,6 +620,7 @@ class LlmLiteRtNpuCompiledModelExecutor : public LlmExecutor {
616620
InferenceContext cache_update_inference_context_;
617621
SortedPrefillSignatureMap prefill_signature_map_;
618622

623+
absl::flat_hash_map<absl::string_view, HWQuantParams> kv_quant_params_;
619624
bool use_hw_ple_for_npu_ = false;
620625
std::vector<const uint8_t*> ple_table_ptrs_;
621626
std::vector<HWQuantizationParams> ple_quant_params_;

0 commit comments

Comments
 (0)