5656#include " runtime/executor/executor_settings_base.h"
5757#include " runtime/executor/litert_compiled_model_executor_utils.h"
5858#include " runtime/executor/llm_executor_io_types.h"
59+ #include " runtime/util/convert_tensor_buffer.h"
5960#include " runtime/util/file_util.h"
6061#include " runtime/util/scoped_file.h"
61- #include " runtime/util/status_macros .h" // NOLINT
62+ #include " runtime/util/tensor_buffer_util .h"
6263#include " tflite/types/half.h" // from @litert
6364
6465namespace litert ::lm {
@@ -838,8 +839,7 @@ AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::CreateNewContext() {
838839 // state.
839840 continue ;
840841 }
841- LITERT_ASSIGN_OR_RETURN (auto new_buffer, compiled_model_.CreateInputBuffer (
842- signature.Key (), name));
842+ LITERT_ASSIGN_OR_RETURN (auto new_buffer, CopyTensorBuffer (env_, buffer));
843843 if (name == kPrevMaskName ) {
844844 LITERT_ASSIGN_OR_RETURN (auto prev_mask_type, buffer.TensorType ());
845845 LITERT_ASSIGN_OR_RETURN (int prev_mask_size,
@@ -860,15 +860,13 @@ absl::StatusOr<std::unique_ptr<AudioStreamingContext>>
860860AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::CloneContext () {
861861 absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer> state_buffers;
862862 LITERT_ASSIGN_OR_RETURN (auto signature, compiled_model_.GetSignature (0 ));
863- for (auto & [name, buffer] : input_buffers_map_) {
863+ for (const auto & [name, buffer] : input_buffers_map_) {
864864 if (name == kSegmentValuesName || name == kSegmentMaskName ) {
865865 // Skip the segment values and mask buffers as they are not part of the
866866 // state.
867867 continue ;
868868 }
869- LITERT_ASSIGN_OR_RETURN (auto new_buffer, compiled_model_.CreateInputBuffer (
870- signature.Key (), name));
871- RETURN_IF_ERROR (CopyBuffer (buffer, new_buffer));
869+ LITERT_ASSIGN_OR_RETURN (auto new_buffer, CopyTensorBuffer (env_, buffer));
872870 state_buffers[name] = std::move (new_buffer);
873871 }
874872 auto audio_streaming_context =
@@ -890,8 +888,28 @@ AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::RestoreContext(
890888 // state.
891889 continue ;
892890 }
893- LITERT_ASSIGN_OR_RETURN (auto buffer_copy, buffer.Duplicate ());
894- input_buffers_map_[name] = std::move (buffer_copy);
891+
892+ if (input_buffers_map_[name].IsMetalMemory ()) {
893+ // b/505373949#comment13: A temporary fix for Metal memory leak.
894+ LITERT_ASSIGN_OR_RETURN (auto tensor_type, buffer.TensorType ());
895+ if (tensor_type.ElementType () == ElementType::Float32) {
896+ LITERT_ASSIGN_OR_RETURN (auto data_span,
897+ ReferTensorBufferAsSpan<float >(buffer));
898+ LITERT_RETURN_IF_ERROR (
899+ input_buffers_map_[name].Write <float >(data_span));
900+ } else if (tensor_type.ElementType () == ElementType::Bool) {
901+ LITERT_ASSIGN_OR_RETURN (auto data_span,
902+ ReferTensorBufferAsSpan<bool >(buffer));
903+ LITERT_RETURN_IF_ERROR (input_buffers_map_[name].Write <bool >(data_span));
904+ } else {
905+ return absl::InvalidArgumentError (
906+ absl::StrCat (" Unsupported element type for state buffer: " ,
907+ tensor_type.ElementType ()));
908+ }
909+ } else {
910+ LITERT_ASSIGN_OR_RETURN (auto buffer_copy, buffer.Duplicate ());
911+ input_buffers_map_[name] = std::move (buffer_copy);
912+ }
895913 }
896914 return absl::OkStatus ();
897915}
0 commit comments