Skip to content

Commit 41200bd

Browse files
yichunk authored and
copybara-github committed
Workaround Metal memory leak by using HostMemory for text and audio state context.
LiteRT-LM-PiperOrigin-RevId: 908794945
1 parent 567a173 commit 41200bd

4 files changed

Lines changed: 51 additions & 14 deletions

File tree

runtime/executor/BUILD

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -929,9 +929,11 @@ cc_library(
929929
"@litert//litert/cc:litert_tensor_buffer_types",
930930
"//runtime/components:model_resources",
931931
"//runtime/engine:io_types",
932+
"//runtime/util:convert_tensor_buffer",
932933
"//runtime/util:file_util",
933934
"//runtime/util:litert_status_util",
934935
"//runtime/util:scoped_file",
936+
"//runtime/util:tensor_buffer_util",
935937
"@litert//tflite/types:half",
936938
] + select({
937939
"@litert//litert:litert_link_capi_so": [

runtime/executor/audio_litert_compiled_model_executor.cc

Lines changed: 27 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -56,9 +56,10 @@
5656
#include "runtime/executor/executor_settings_base.h"
5757
#include "runtime/executor/litert_compiled_model_executor_utils.h"
5858
#include "runtime/executor/llm_executor_io_types.h"
59+
#include "runtime/util/convert_tensor_buffer.h"
5960
#include "runtime/util/file_util.h"
6061
#include "runtime/util/scoped_file.h"
61-
#include "runtime/util/status_macros.h" //NOLINT
62+
#include "runtime/util/tensor_buffer_util.h"
6263
#include "tflite/types/half.h" // from @litert
6364

6465
namespace litert::lm {
@@ -838,8 +839,7 @@ AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::CreateNewContext() {
838839
// state.
839840
continue;
840841
}
841-
LITERT_ASSIGN_OR_RETURN(auto new_buffer, compiled_model_.CreateInputBuffer(
842-
signature.Key(), name));
842+
LITERT_ASSIGN_OR_RETURN(auto new_buffer, CopyTensorBuffer(env_, buffer));
843843
if (name == kPrevMaskName) {
844844
LITERT_ASSIGN_OR_RETURN(auto prev_mask_type, buffer.TensorType());
845845
LITERT_ASSIGN_OR_RETURN(int prev_mask_size,
@@ -860,15 +860,13 @@ absl::StatusOr<std::unique_ptr<AudioStreamingContext>>
860860
AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::CloneContext() {
861861
absl::flat_hash_map<absl::string_view, ::litert::TensorBuffer> state_buffers;
862862
LITERT_ASSIGN_OR_RETURN(auto signature, compiled_model_.GetSignature(0));
863-
for (auto& [name, buffer] : input_buffers_map_) {
863+
for (const auto& [name, buffer] : input_buffers_map_) {
864864
if (name == kSegmentValuesName || name == kSegmentMaskName) {
865865
// Skip the segment values and mask buffers as they are not part of the
866866
// state.
867867
continue;
868868
}
869-
LITERT_ASSIGN_OR_RETURN(auto new_buffer, compiled_model_.CreateInputBuffer(
870-
signature.Key(), name));
871-
RETURN_IF_ERROR(CopyBuffer(buffer, new_buffer));
869+
LITERT_ASSIGN_OR_RETURN(auto new_buffer, CopyTensorBuffer(env_, buffer));
872870
state_buffers[name] = std::move(new_buffer);
873871
}
874872
auto audio_streaming_context =
@@ -890,8 +888,28 @@ AudioLiteRtCompiledModelExecutor::AudioStreamingEncoder::RestoreContext(
890888
// state.
891889
continue;
892890
}
893-
LITERT_ASSIGN_OR_RETURN(auto buffer_copy, buffer.Duplicate());
894-
input_buffers_map_[name] = std::move(buffer_copy);
891+
892+
if (input_buffers_map_[name].IsMetalMemory()) {
893+
// b/505373949#comment13: A temporary fix for Metal memory leak.
894+
LITERT_ASSIGN_OR_RETURN(auto tensor_type, buffer.TensorType());
895+
if (tensor_type.ElementType() == ElementType::Float32) {
896+
LITERT_ASSIGN_OR_RETURN(auto data_span,
897+
ReferTensorBufferAsSpan<float>(buffer));
898+
LITERT_RETURN_IF_ERROR(
899+
input_buffers_map_[name].Write<float>(data_span));
900+
} else if (tensor_type.ElementType() == ElementType::Bool) {
901+
LITERT_ASSIGN_OR_RETURN(auto data_span,
902+
ReferTensorBufferAsSpan<bool>(buffer));
903+
LITERT_RETURN_IF_ERROR(input_buffers_map_[name].Write<bool>(data_span));
904+
} else {
905+
return absl::InvalidArgumentError(
906+
absl::StrCat("Unsupported element type for state buffer: ",
907+
tensor_type.ElementType()));
908+
}
909+
} else {
910+
LITERT_ASSIGN_OR_RETURN(auto buffer_copy, buffer.Duplicate());
911+
input_buffers_map_[name] = std::move(buffer_copy);
912+
}
895913
}
896914
return absl::OkStatus();
897915
}

runtime/util/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -350,6 +350,7 @@ cc_library(
350350
"@com_google_absl//absl/types:span",
351351
"@litert//litert/cc:litert_macros",
352352
"@litert//litert/cc:litert_ranked_tensor_type",
353+
"@litert//litert/cc:litert_tensor_buffer_types",
353354
] + select({
354355
"@litert//litert:litert_link_capi_so": [
355356
"@litert//litert/cc:litert_api_with_dynamic_runtime",

runtime/util/tensor_buffer_util.cc

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,15 @@
1515
#include "runtime/util/tensor_buffer_util.h"
1616

1717
#include <cstring>
18+
#include <memory>
1819
#include <utility>
1920
#include <vector>
2021

2122
#include "absl/status/statusor.h" // from @com_google_absl
2223
#include "litert/cc/litert_environment.h" // from @litert
2324
#include "litert/cc/litert_macros.h" // from @litert
2425
#include "litert/cc/litert_tensor_buffer.h" // from @litert
26+
#include "litert/cc/litert_tensor_buffer_types.h" // from @litert
2527

2628
namespace litert::lm {
2729

@@ -48,9 +50,23 @@ absl::StatusOr<::litert::TensorBuffer> CopyTensorBuffer(
4850
LITERT_ASSIGN_OR_RETURN(auto buffer_type, tensor_buffer.BufferType());
4951
LITERT_ASSIGN_OR_RETURN(auto size, tensor_buffer.PackedSize());
5052

51-
LITERT_ASSIGN_OR_RETURN(auto output_tensor_buffer,
52-
::litert::TensorBuffer::CreateManaged(
53-
env, buffer_type, tensor_type, size));
53+
std::unique_ptr<::litert::TensorBuffer> output_tensor_buffer;
54+
if (tensor_buffer.IsMetalMemory()) {
55+
// b/505373949#comment13: A temporary fix to create a host memory buffer to
56+
// copy from the metal memory buffer to avoid memory leak:
57+
LITERT_ASSIGN_OR_RETURN(
58+
auto buffer,
59+
::litert::TensorBuffer::CreateManaged(
60+
env, ::litert::TensorBufferType::kHostMemory, tensor_type, size));
61+
output_tensor_buffer =
62+
std::make_unique<::litert::TensorBuffer>(std::move(buffer));
63+
} else {
64+
LITERT_ASSIGN_OR_RETURN(
65+
auto buffer, ::litert::TensorBuffer::CreateManaged(env, buffer_type,
66+
tensor_type, size));
67+
output_tensor_buffer =
68+
std::make_unique<::litert::TensorBuffer>(std::move(buffer));
69+
}
5470

5571
LITERT_ASSIGN_OR_RETURN(
5672
auto src_lock_and_addr,
@@ -59,11 +75,11 @@ absl::StatusOr<::litert::TensorBuffer> CopyTensorBuffer(
5975
LITERT_ASSIGN_OR_RETURN(
6076
auto dst_lock_and_addr,
6177
::litert::TensorBufferScopedLock::Create(
62-
output_tensor_buffer, ::litert::TensorBuffer::LockMode::kWrite));
78+
*output_tensor_buffer, ::litert::TensorBuffer::LockMode::kWrite));
6379

6480
std::memcpy(dst_lock_and_addr.second, src_lock_and_addr.second, size);
6581

66-
return std::move(output_tensor_buffer);
82+
return std::move(*output_tensor_buffer);
6783
}
6884

6985
} // namespace litert::lm

0 commit comments

Comments
 (0)