@@ -845,9 +845,28 @@ absl::Status LlmLiteRtNpuCompiledModelExecutor::AllocateTransformerBuffers(
845845 absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>&
846846 decode_output_kv_cache_slice_buffers,
847847 absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer>&
848- verify_output_kv_cache_slice_buffers) {
848+ verify_output_kv_cache_slice_buffers,
849+ absl::flat_hash_map<absl::string_view, HWQuantParams>& kv_quant_params) {
849850 auto prefill_signature = transformer_model->FindSignature (kPrefillSignature );
850851
852+ if (prefill_signature.HasValue ()) {
853+ for (auto output_name : prefill_signature->OutputNames ()) {
854+ if (absl::StartsWith (output_name, kv_cache_slice_k_root_name) ||
855+ absl::StartsWith (output_name, kv_cache_slice_v_root_name)) {
856+ auto tensor_expected = prefill_signature->OutputTensor (output_name);
857+ if (tensor_expected.HasValue ()) {
858+ HWQuantParams q_params;
859+ if (tensor_expected->HasQuantization ()) {
860+ auto pq = tensor_expected->PerTensorQuantization ();
861+ q_params.scale = pq.scale ;
862+ q_params.zero_point = pq.zero_point ;
863+ }
864+ kv_quant_params[output_name] = q_params;
865+ }
866+ }
867+ }
868+ }
869+
851870 // Create input buffers for prefill signature.
852871 for (auto input_name : prefill_signature->InputNames ()) {
853872 if (absl::StartsWith (input_name, kv_cache_k_root_name) ||
@@ -1986,7 +2005,8 @@ absl::Status LlmLiteRtNpuCompiledModelExecutor::PrefillInternal(
19862005 if (prefill_kv_cache_update_method_ == KVCacheUpdateMethod::kWH ) {
19872006 RETURN_IF_ERROR (HWKVCacheUpdate (
19882007 cache_update_inference_context_.prefill_input_buffers ,
1989- cache_update_inference_context_.prefill_output_buffers ));
2008+ cache_update_inference_context_.prefill_output_buffers ,
2009+ kv_quant_params_));
19902010 } else {
19912011 auto res = npu_auxiliary_context_.npu_auxiliary_compiled_model .Run (
19922012 CacheUpdateSignatures::kPrefillCacheUpdate ,
@@ -2162,9 +2182,10 @@ absl::Status LlmLiteRtNpuCompiledModelExecutor::DecodeInternal(
21622182 {
21632183 auto start = absl::Now ();
21642184 if (decode_kv_cache_update_method_ == KVCacheUpdateMethod::kWH ) {
2165- RETURN_IF_ERROR (HWKVCacheUpdate (
2166- cache_update_inference_context_.decode_input_buffers ,
2167- cache_update_inference_context_.decode_output_buffers ));
2185+ RETURN_IF_ERROR (
2186+ HWKVCacheUpdate (cache_update_inference_context_.decode_input_buffers ,
2187+ cache_update_inference_context_.decode_output_buffers ,
2188+ kv_quant_params_));
21682189 } else {
21692190 auto res = npu_auxiliary_context_.npu_auxiliary_compiled_model .Run (
21702191 CacheUpdateSignatures::kDecodeCacheUpdate ,
@@ -2628,7 +2649,8 @@ absl::Status LlmLiteRtNpuCompiledModelExecutor::CommitVerifiedKVCache(
26282649 if (prefill_kv_cache_update_method_ == KVCacheUpdateMethod::kWH ) {
26292650 RETURN_IF_ERROR (
26302651 HWKVCacheUpdate (cache_update_inference_context_.verify_input_buffers ,
2631- cache_update_inference_context_.verify_output_buffers ));
2652+ cache_update_inference_context_.verify_output_buffers ,
2653+ kv_quant_params_));
26322654 } else {
26332655 LITERT_RETURN_IF_ERROR (
26342656 npu_auxiliary_context_.npu_auxiliary_compiled_model .Run (
@@ -2742,12 +2764,13 @@ LlmLiteRtNpuCompiledModelExecutor::CreateForModelHasPerLayerEmbedding(
27422764 absl::flat_hash_map<absl::string_view, TensorBuffer>
27432765 verify_output_kv_cache_slice_buffers;
27442766
2767+ absl::flat_hash_map<absl::string_view, HWQuantParams> kv_quant_params;
27452768 RETURN_IF_ERROR (AllocateTransformerBuffers (
27462769 env, transformer_model, llm_compiled_model, gemma_prefill_input_buffers,
27472770 gemma_decode_input_buffers, gemma_verify_input_buffers,
27482771 input_kv_cache_buffers, prefill_output_kv_cache_slice_buffers,
27492772 decode_output_kv_cache_slice_buffers,
2750- verify_output_kv_cache_slice_buffers));
2773+ verify_output_kv_cache_slice_buffers, kv_quant_params ));
27512774
27522775 // Gemma3n specific fix: KV cache buffer 19 of *prefill* is not connected
27532776 // to any OPs in the model, making the LiteRT runtime allocate host memory
@@ -2978,8 +3001,8 @@ LlmLiteRtNpuCompiledModelExecutor::CreateForModelHasPerLayerEmbedding(
29783001 std::move (embedder_per_layer_context), quantization_params,
29793002 std::move (ple_table_ptrs), std::move (ple_quant_params),
29803003 std::move (ple_per_tensor_scales), table_count, output_type, final_scale,
2981- final_zero_point, speculative_decoding_type, std::move (drafter_context) ,
2982- std::move (drafter_aux_context)));
3004+ final_zero_point, std::move (kv_quant_params), speculative_decoding_type ,
3005+ std::move (drafter_context), std::move ( drafter_aux_context)));
29833006 return executor;
29843007}
29853008
@@ -3013,12 +3036,13 @@ LlmLiteRtNpuCompiledModelExecutor::CreateForModelWithoutPerLayerEmbedding(
30133036 absl::flat_hash_map<absl::string_view, TensorBuffer>
30143037 verify_output_kv_cache_slice_buffers;
30153038
3039+ absl::flat_hash_map<absl::string_view, HWQuantParams> kv_quant_params;
30163040 RETURN_IF_ERROR (AllocateTransformerBuffers (
30173041 env, transformer_model, llm_compiled_model, gemma_prefill_input_buffers,
30183042 gemma_decode_input_buffers, gemma_verify_input_buffers,
30193043 input_kv_cache_buffers, prefill_output_kv_cache_slice_buffers,
30203044 decode_output_kv_cache_slice_buffers,
3021- verify_output_kv_cache_slice_buffers));
3045+ verify_output_kv_cache_slice_buffers, kv_quant_params ));
30223046 LITERT_ASSIGN_OR_RETURN (
30233047 auto llm_inference_context,
30243048 CreateLlmInferenceContextWithBufferSharing (
@@ -3182,8 +3206,9 @@ LlmLiteRtNpuCompiledModelExecutor::CreateForModelWithoutPerLayerEmbedding(
31823206 std::move (cache_update_inference_context), std::move (prefill_runner_set),
31833207 std::move (maybe_embedding_lookup_manager),
31843208 /* embedder_per_layer_context=*/ std::nullopt , quantization_params, {}, {},
3185- {}, 0 , litert::ElementType::None, 1 .0f , 0 , speculative_decoding_type,
3186- std::move (drafter_context), std::move (drafter_aux_context)));
3209+ {}, 0 , litert::ElementType::None, 1 .0f , 0 , std::move (kv_quant_params),
3210+ speculative_decoding_type, std::move (drafter_context),
3211+ std::move (drafter_aux_context)));
31873212 return executor;
31883213}
31893214
0 commit comments