Skip to content

Commit d6e36f2

Browse files
ai-edge-botcopybara-github
authored andcommitted
This is an internal change.
LiteRT-LM-PiperOrigin-RevId: 900751205
1 parent 1789e40 commit d6e36f2

33 files changed

Lines changed: 814 additions & 361 deletions

runtime/components/embedding_lookup/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ cc_library(
2828
deps = [
2929
"@com_google_absl//absl/status",
3030
"@com_google_absl//absl/types:span",
31+
"@litert//litert/cc:litert_expected",
3132
] + select({
3233
"@litert//litert:litert_link_capi_so": [
3334
"@litert//litert/cc:litert_api_with_dynamic_runtime",
@@ -51,6 +52,8 @@ cc_library(
5152
"@com_google_absl//absl/strings",
5253
"@com_google_absl//absl/strings:string_view",
5354
"@com_google_absl//absl/types:span",
55+
"@litert//litert/c:litert_common",
56+
"@litert//litert/cc:litert_expected",
5457
"//runtime/util:litert_status_util",
5558
] + select({
5659
"@litert//litert:litert_link_capi_so": [
@@ -152,6 +155,8 @@ cc_library(
152155
"@com_google_absl//absl/status:statusor",
153156
"@com_google_absl//absl/strings:string_view",
154157
"@com_google_absl//absl/types:span",
158+
"@litert//litert/c:litert_common",
159+
"@litert//litert/cc:litert_expected",
155160
"@litert//litert/cc:litert_macros",
156161
"//runtime/executor:llm_executor_io_types",
157162
"//runtime/util:litert_status_util",

runtime/components/embedding_lookup/embedding_lookup.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include "absl/status/status.h" // from @com_google_absl
2424
#include "absl/types/span.h" // from @com_google_absl
25+
#include "litert/cc/litert_expected.h" // from @litert
2526
#include "litert/cc/litert_tensor_buffer.h" // from @litert
2627

2728
namespace litert::lm {
@@ -69,6 +70,9 @@ class EmbeddingLookup {
6970
virtual absl::Status LookupPrefill(absl::Span<const int> tokens,
7071
litert::TensorBuffer* output_tensor,
7172
size_t byte_offset) = 0;
73+
74+
// Returns whether the embedding lookup compiled model is fully accelerated.
75+
virtual litert::Expected<bool> IsFullyAccelerated() = 0;
7276
};
7377

7478
} // namespace litert::lm

runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,13 @@ EndOfMultiModalEmbedding::Create(litert::Environment& env,
154154
const litert::Model* absl_nonnull model,
155155
int special_token) {
156156
auto handler = std::unique_ptr<EndOfMultiModalEmbedding>(
157-
new EndOfMultiModalEmbedding(env, model, special_token));
157+
new EndOfMultiModalEmbedding(env, special_token));
158158
RETURN_IF_ERROR( // IWYU pragma: keep as is included by status_macros.h
159-
handler->Initialize());
159+
handler->Initialize(*model));
160160
return handler;
161161
}
162162

163-
absl::Status EndOfMultiModalEmbedding::Initialize() {
163+
absl::Status EndOfMultiModalEmbedding::Initialize(const Model& model) {
164164
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
165165
#if defined(__ANDROID__)
166166
options.SetHardwareAccelerators(litert::HwAccelerators::kNpu |
@@ -179,8 +179,8 @@ absl::Status EndOfMultiModalEmbedding::Initialize() {
179179

180180
LITERT_ASSIGN_OR_RETURN(
181181
litert::CompiledModel compiled_model,
182-
litert::CompiledModel::Create(env_, model_.Get(), options));
183-
if (auto num_signatures = model_.GetNumSignatures(); num_signatures != 1) {
182+
litert::CompiledModel::Create(env_, model.Get(), options));
183+
if (auto num_signatures = model.GetNumSignatures(); num_signatures != 1) {
184184
return absl::InvalidArgumentError(absl::StrCat(
185185
"The Embedding model must have exactly one signature but got ",
186186
num_signatures));
@@ -237,6 +237,10 @@ absl::Status EndOfMultiModalEmbedding::Initialize() {
237237
size_t bytes = end_of_multi_modal_embedding_.size() * sizeof(float);
238238
output_buffers[0].Read(absl::MakeSpan(data_ptr, bytes));
239239

240+
if (auto res = compiled_model.IsFullyAccelerated(); res.HasValue()) {
241+
is_fully_accelerated_ = *res;
242+
}
243+
240244
return absl::OkStatus();
241245
}
242246

runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
#include <cstddef>
2121
#include <memory>
22-
#include <utility>
2322
#include <vector>
2423

2524
#include "absl/base/nullability.h" // from @com_google_absl
@@ -29,7 +28,6 @@
2928
#include "litert/cc/litert_environment.h" // from @litert
3029
#include "litert/cc/litert_layout.h" // from @litert
3130
#include "litert/cc/litert_model.h" // from @litert
32-
#include "litert/cc/litert_options.h" // from @litert
3331
#include "litert/cc/litert_tensor_buffer.h" // from @litert
3432
#include "runtime/components/embedding_lookup/embedding_lookup.h"
3533

@@ -74,20 +72,19 @@ class EndOfMultiModalEmbedding : public EmbeddingLookup {
7472
litert::TensorBuffer* prefill_output,
7573
size_t byte_offset) override;
7674

75+
litert::Expected<bool> IsFullyAccelerated() override {
76+
return is_fully_accelerated_;
77+
}
78+
7779
protected:
78-
EndOfMultiModalEmbedding(litert::Environment& env,
79-
const litert::Model* absl_nonnull model,
80-
int special_token)
81-
: env_(env), model_(*model), special_token_(special_token) {}
80+
EndOfMultiModalEmbedding(litert::Environment& env, int special_token)
81+
: env_(env), special_token_(special_token) {}
8282

8383
// Loads the provided model. This must be called before Lookup functions.
84-
absl::Status Initialize();
84+
absl::Status Initialize(const litert::Model& model);
8585

8686
// The environment for the embedding lookup.
8787
litert::Environment& env_;
88-
// The model for the embedding lookup. The actual model instance is owned by
89-
// the model resources.
90-
const litert::Model& model_;
9188

9289
// The layout of the output tensor from the embedding model.
9390
litert::Layout output_buffer_layout_;
@@ -99,6 +96,8 @@ class EndOfMultiModalEmbedding : public EmbeddingLookup {
9996
// Contains the end of multi-modal embedding that was looked up from the
10097
// model.
10198
std::vector<float> end_of_multi_modal_embedding_;
99+
100+
bool is_fully_accelerated_ = false;
102101
};
103102

104103
} // namespace litert::lm

runtime/components/embedding_lookup/embedding_lookup_manager.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
#include "absl/status/statusor.h" // from @com_google_absl
3030
#include "absl/strings/string_view.h" // from @com_google_absl
3131
#include "absl/types/span.h" // from @com_google_absl
32+
#include "litert/c/litert_common.h" // from @litert
3233
#include "litert/cc/litert_environment.h" // from @litert
34+
#include "litert/cc/litert_expected.h" // from @litert
3335
#include "litert/cc/litert_macros.h" // from @litert
3436
#include "litert/cc/litert_model.h" // from @litert
3537
#include "litert/cc/litert_tensor_buffer.h" // from @litert
@@ -265,4 +267,21 @@ absl::Status EmbeddingLookupManager::Initialize(
265267
return absl::OkStatus();
266268
}
267269

270+
litert::Expected<bool> EmbeddingLookupManager::IsFullyAccelerated() const {
271+
if (text_embedding_lookup_ == nullptr) {
272+
return litert::Error(kLiteRtStatusErrorRuntimeFailure,
273+
"Text embedding lookup has not been created.");
274+
}
275+
if (auto res = text_embedding_lookup_->IsFullyAccelerated();
276+
!res.HasValue() || !*res) {
277+
return res;
278+
}
279+
for (const auto& lookup : end_of_multi_modal_embedding_lookups_) {
280+
if (auto res = lookup->IsFullyAccelerated(); !res.HasValue() || !*res) {
281+
return res;
282+
}
283+
}
284+
return true;
285+
}
286+
268287
} // namespace litert::lm

runtime/components/embedding_lookup/embedding_lookup_manager.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/status/statusor.h" // from @com_google_absl
2828
#include "absl/types/span.h" // from @com_google_absl
2929
#include "litert/cc/litert_environment.h" // from @litert
30+
#include "litert/cc/litert_expected.h" // from @litert
3031
#include "litert/cc/litert_model.h" // from @litert
3132
#include "litert/cc/litert_tensor_buffer.h" // from @litert
3233
#include "runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h"
@@ -116,6 +117,8 @@ class EmbeddingLookupManager {
116117
return text_embedding_lookup_.get();
117118
}
118119

120+
litert::Expected<bool> IsFullyAccelerated() const;
121+
119122
protected:
120123
absl::Status Initialize(
121124
litert::Environment& env,

runtime/components/embedding_lookup/embedding_lookup_multi_modal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ class EmbeddingLookupMultiModal : public EmbeddingLookup {
9191
// Returns true if there are any embeddings left to be read.
9292
bool HasRemainingEmbeddings() const { return embedding_.size() > 0; }
9393

94+
litert::Expected<bool> IsFullyAccelerated() override { return true; }
95+
9496
protected:
9597
absl::Status Initialize(const ::litert::TensorBuffer* embedding_buffer,
9698
int special_token);

runtime/components/embedding_lookup/embedding_lookup_text.cc

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@
3131
#include "absl/strings/string_view.h" // from @com_google_absl
3232
#include "absl/types/span.h" // from @com_google_absl
3333
#include "litert/cc/litert_common.h" // from @litert
34-
#include "litert/cc/litert_compiled_model.h" // from @litert
35-
#include "litert/cc/litert_element_type.h" // from @litert
36-
#include "litert/cc/litert_environment.h" // from @litert
37-
#include "litert/cc/litert_macros.h" // from @litert
34+
#include "litert/cc/litert_expected.h" // from @litert
3835
#include "litert/cc/litert_model.h" // from @litert
3936
#include "litert/cc/litert_options.h" // from @litert
4037
#include "litert/cc/litert_tensor_buffer.h" // from @litert
@@ -248,12 +245,12 @@ EmbeddingLookupText::Create(litert::Environment& env,
248245
const litert::Model* absl_nonnull model,
249246
std::optional<std::string> signature_key) {
250247
auto handler = std::unique_ptr<EmbeddingLookupText>(
251-
new EmbeddingLookupText(env, model, signature_key));
252-
RETURN_IF_ERROR(handler->Initialize());
248+
new EmbeddingLookupText(env, signature_key));
249+
RETURN_IF_ERROR(handler->Initialize(*model));
253250
return handler;
254251
}
255252

256-
absl::Status EmbeddingLookupText::Initialize() {
253+
absl::Status EmbeddingLookupText::Initialize(const litert::Model& model) {
257254
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
258255
#if defined(__ANDROID__)
259256
options.SetHardwareAccelerators(litert::HwAccelerators::kNpu |
@@ -271,8 +268,8 @@ absl::Status EmbeddingLookupText::Initialize() {
271268
#endif
272269

273270
LITERT_ASSIGN_OR_RETURN(compiled_model_, litert::CompiledModel::Create(
274-
env_, model_.Get(), options));
275-
LITERT_ASSIGN_OR_RETURN(auto signatures, model_.GetSignatures());
271+
env_, model.Get(), options));
272+
LITERT_ASSIGN_OR_RETURN(auto signatures, model.GetSignatures());
276273

277274
if (signature_key_.has_value()) {
278275
bool found = false;
@@ -354,4 +351,12 @@ absl::Status EmbeddingLookupText::Initialize() {
354351
return absl::OkStatus();
355352
}
356353

354+
litert::Expected<bool> EmbeddingLookupText::IsFullyAccelerated() {
355+
if (!compiled_model_.has_value()) {
356+
return litert::Error(litert::Status::kErrorRuntimeFailure,
357+
"Compiled model has not been created.");
358+
}
359+
return compiled_model_->IsFullyAccelerated();
360+
}
361+
357362
} // namespace litert::lm

runtime/components/embedding_lookup/embedding_lookup_text.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include <memory>
2323
#include <optional>
2424
#include <string>
25-
#include <utility>
2625
#include <vector>
2726

2827
#include "absl/base/nullability.h" // from @com_google_absl
@@ -31,6 +30,7 @@
3130
#include "absl/types/span.h" // from @com_google_absl
3231
#include "litert/cc/litert_compiled_model.h" // from @litert
3332
#include "litert/cc/litert_environment.h" // from @litert
33+
#include "litert/cc/litert_expected.h" // from @litert
3434
#include "litert/cc/litert_model.h" // from @litert
3535
#include "litert/cc/litert_options.h" // from @litert
3636
#include "litert/cc/litert_ranked_tensor_type.h" // from @litert
@@ -103,6 +103,8 @@ class EmbeddingLookupText : public EmbeddingLookup {
103103
// Returns number of floats per token in the output tensor.
104104
size_t GetFloatsPerToken();
105105

106+
litert::Expected<bool> IsFullyAccelerated() override;
107+
106108
// Returns the default embedding vector to use when a token is not found in
107109
// the lookup table.
108110
const std::vector<float>& GetDefaultEmbeddingVector() const {
@@ -116,22 +118,18 @@ class EmbeddingLookupText : public EmbeddingLookup {
116118

117119
protected:
118120
EmbeddingLookupText(litert::Environment& env,
119-
const litert::Model* absl_nonnull model,
120121
std::optional<std::string> signature_key)
121-
: env_(env), model_(*model), signature_key_(signature_key) {}
122+
: env_(env), signature_key_(signature_key) {}
122123

123124
// Loads the provided model. This must be called before Lookup.
124-
absl::Status Initialize();
125+
absl::Status Initialize(const litert::Model& model);
125126

126127
// Internal implementation of Lookup for both the single and multiple token
127128
// cases.
128129
absl::Status LookupInternal(int token, absl::Span<uint8_t> buffer);
129130

130131
// The environment for the embedding lookup.
131132
litert::Environment& env_;
132-
// The model for the embedding lookup. The actual model instance is owned by
133-
// the model resources.
134-
const litert::Model& model_;
135133
// The compiled model for the embedding model.
136134
std::optional<litert::CompiledModel> compiled_model_;
137135

runtime/components/model_resources.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ class ModelResources {
185185

186186
// Returns the llm metadata.
187187
virtual absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() = 0;
188+
189+
// Releases the TFLite model from RAM. This is used to reduce peak memory
190+
// usage after the model has been compiled into a hardware-specific
191+
// executable.
192+
virtual absl::Status ReleaseTFLiteModel(ModelType model_type) = 0;
188193
};
189194

190195
} // namespace litert::lm

0 commit comments

Comments
 (0)