Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions runtime/components/embedding_lookup/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ cc_library(
deps = [
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:span",
"@litert//litert/cc:litert_expected",
] + select({
"@litert//litert:litert_link_capi_so": [
"@litert//litert/cc:litert_api_with_dynamic_runtime",
Expand All @@ -51,6 +52,8 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@litert//litert/c:litert_common",
"@litert//litert/cc:litert_expected",
"//runtime/util:litert_status_util",
] + select({
"@litert//litert:litert_link_capi_so": [
Expand Down Expand Up @@ -152,6 +155,8 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@litert//litert/c:litert_common",
"@litert//litert/cc:litert_expected",
"@litert//litert/cc:litert_macros",
"//runtime/executor:llm_executor_io_types",
"//runtime/util:litert_status_util",
Expand Down
4 changes: 4 additions & 0 deletions runtime/components/embedding_lookup/embedding_lookup.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "absl/status/status.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/cc/litert_expected.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert

namespace litert::lm {
Expand Down Expand Up @@ -69,6 +70,9 @@ class EmbeddingLookup {
virtual absl::Status LookupPrefill(absl::Span<const int> tokens,
litert::TensorBuffer* output_tensor,
size_t byte_offset) = 0;

// Returns whether the embedding lookup compiled model is fully accelerated.
// The result is wrapped in litert::Expected so that an implementation can
// report an error instead of a value (for example, when it has no compiled
// model to query yet).
virtual litert::Expected<bool> IsFullyAccelerated() = 0;
};

} // namespace litert::lm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ EndOfMultiModalEmbedding::Create(litert::Environment& env,
const litert::Model* absl_nonnull model,
int special_token) {
auto handler = std::unique_ptr<EndOfMultiModalEmbedding>(
new EndOfMultiModalEmbedding(env, model, special_token));
new EndOfMultiModalEmbedding(env, special_token));
RETURN_IF_ERROR( // IWYU pragma: keep as is included by status_macros.h
handler->Initialize());
handler->Initialize(*model));
return handler;
}

absl::Status EndOfMultiModalEmbedding::Initialize() {
absl::Status EndOfMultiModalEmbedding::Initialize(const Model& model) {
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
#if defined(__ANDROID__)
options.SetHardwareAccelerators(litert::HwAccelerators::kNpu |
Expand All @@ -179,8 +179,8 @@ absl::Status EndOfMultiModalEmbedding::Initialize() {

LITERT_ASSIGN_OR_RETURN(
litert::CompiledModel compiled_model,
litert::CompiledModel::Create(env_, model_.Get(), options));
if (auto num_signatures = model_.GetNumSignatures(); num_signatures != 1) {
litert::CompiledModel::Create(env_, model.Get(), options));
if (auto num_signatures = model.GetNumSignatures(); num_signatures != 1) {
return absl::InvalidArgumentError(absl::StrCat(
"The Embedding model must have exactly one signature but got ",
num_signatures));
Expand Down Expand Up @@ -237,6 +237,10 @@ absl::Status EndOfMultiModalEmbedding::Initialize() {
size_t bytes = end_of_multi_modal_embedding_.size() * sizeof(float);
output_buffers[0].Read(absl::MakeSpan(data_ptr, bytes));

if (auto res = compiled_model.IsFullyAccelerated(); res.HasValue()) {
is_fully_accelerated_ = *res;
}

return absl::OkStatus();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

#include "absl/base/nullability.h" // from @com_google_absl
Expand All @@ -29,7 +28,6 @@
#include "litert/cc/litert_environment.h" // from @litert
#include "litert/cc/litert_layout.h" // from @litert
#include "litert/cc/litert_model.h" // from @litert
#include "litert/cc/litert_options.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert
#include "runtime/components/embedding_lookup/embedding_lookup.h"

Expand Down Expand Up @@ -74,20 +72,19 @@ class EndOfMultiModalEmbedding : public EmbeddingLookup {
litert::TensorBuffer* prefill_output,
size_t byte_offset) override;

// Returns the acceleration flag cached in is_fully_accelerated_. Initialize()
// only overwrites that flag when the compiled model's IsFullyAccelerated()
// query yields a value, so this reports the default (false) rather than an
// error when the query failed.
litert::Expected<bool> IsFullyAccelerated() override {
return is_fully_accelerated_;
}

protected:
EndOfMultiModalEmbedding(litert::Environment& env,
const litert::Model* absl_nonnull model,
int special_token)
: env_(env), model_(*model), special_token_(special_token) {}
EndOfMultiModalEmbedding(litert::Environment& env, int special_token)
: env_(env), special_token_(special_token) {}

// Loads the provided model. This must be called before Lookup functions.
absl::Status Initialize();
absl::Status Initialize(const litert::Model& model);

// The environment for the embedding lookup.
litert::Environment& env_;
// The model for the embedding lookup. The actual model instance is owned by
// the model resources.
const litert::Model& model_;

// The layout of the output tensor from the embedding model.
litert::Layout output_buffer_layout_;
Expand All @@ -99,6 +96,8 @@ class EndOfMultiModalEmbedding : public EmbeddingLookup {
// Contains the end of multi-modal embedding that was looked up from the
// model.
std::vector<float> end_of_multi_modal_embedding_;

bool is_fully_accelerated_ = false;
};

} // namespace litert::lm
Expand Down
19 changes: 19 additions & 0 deletions runtime/components/embedding_lookup/embedding_lookup_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/c/litert_common.h" // from @litert
#include "litert/cc/litert_environment.h" // from @litert
#include "litert/cc/litert_expected.h" // from @litert
#include "litert/cc/litert_macros.h" // from @litert
#include "litert/cc/litert_model.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert
Expand Down Expand Up @@ -265,4 +267,21 @@ absl::Status EmbeddingLookupManager::Initialize(
return absl::OkStatus();
}

// Reports whether every embedding lookup owned by this manager is fully
// accelerated.
//
// Returns an error if the text embedding lookup has not been created. The
// text lookup is queried first, then each end-of-multi-modal lookup in order;
// the first result that is either an error or `false` is returned as-is.
// Returns `true` only when every queried lookup reports `true`.
litert::Expected<bool> EmbeddingLookupManager::IsFullyAccelerated() const {
  if (!text_embedding_lookup_) {
    return litert::Error(kLiteRtStatusErrorRuntimeFailure,
                         "Text embedding lookup has not been created.");
  }

  // The text lookup goes first: a failure or a `false` here short-circuits.
  auto text_status = text_embedding_lookup_->IsFullyAccelerated();
  if (!(text_status.HasValue() && *text_status)) {
    return text_status;
  }

  for (const auto& end_of_mm_lookup : end_of_multi_modal_embedding_lookups_) {
    auto lookup_status = end_of_mm_lookup->IsFullyAccelerated();
    if (lookup_status.HasValue() && *lookup_status) {
      continue;  // This lookup is fully accelerated; check the next one.
    }
    return lookup_status;
  }

  return true;
}

} // namespace litert::lm
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/cc/litert_environment.h" // from @litert
#include "litert/cc/litert_expected.h" // from @litert
#include "litert/cc/litert_model.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert
#include "runtime/components/embedding_lookup/embedding_lookup_end_of_multi_modal.h"
Expand Down Expand Up @@ -116,6 +117,8 @@ class EmbeddingLookupManager {
return text_embedding_lookup_.get();
}

// Returns true only if the text embedding lookup and every end-of-multi-modal
// embedding lookup report full acceleration. Returns an error if the text
// embedding lookup has not been created; otherwise forwards the first result
// from a lookup that is an error or false.
litert::Expected<bool> IsFullyAccelerated() const;

protected:
absl::Status Initialize(
litert::Environment& env,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class EmbeddingLookupMultiModal : public EmbeddingLookup {
// Returns true if there are any embeddings left to be read.
bool HasRemainingEmbeddings() const { return embedding_.size() > 0; }

// Always reports full acceleration.
// NOTE(review): presumably because this lookup is initialized from a provided
// embedding buffer rather than a compiled model -- confirm.
litert::Expected<bool> IsFullyAccelerated() override { return true; }

protected:
absl::Status Initialize(const ::litert::TensorBuffer* embedding_buffer,
int special_token);
Expand Down
23 changes: 14 additions & 9 deletions runtime/components/embedding_lookup/embedding_lookup_text.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/cc/litert_common.h" // from @litert
#include "litert/cc/litert_compiled_model.h" // from @litert
#include "litert/cc/litert_element_type.h" // from @litert
#include "litert/cc/litert_environment.h" // from @litert
#include "litert/cc/litert_macros.h" // from @litert
#include "litert/cc/litert_expected.h" // from @litert
#include "litert/cc/litert_model.h" // from @litert
#include "litert/cc/litert_options.h" // from @litert
#include "litert/cc/litert_tensor_buffer.h" // from @litert
Expand Down Expand Up @@ -248,12 +245,12 @@ EmbeddingLookupText::Create(litert::Environment& env,
const litert::Model* absl_nonnull model,
std::optional<std::string> signature_key) {
auto handler = std::unique_ptr<EmbeddingLookupText>(
new EmbeddingLookupText(env, model, signature_key));
RETURN_IF_ERROR(handler->Initialize());
new EmbeddingLookupText(env, signature_key));
RETURN_IF_ERROR(handler->Initialize(*model));
return handler;
}

absl::Status EmbeddingLookupText::Initialize() {
absl::Status EmbeddingLookupText::Initialize(const litert::Model& model) {
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
#if defined(__ANDROID__)
options.SetHardwareAccelerators(litert::HwAccelerators::kNpu |
Expand All @@ -271,8 +268,8 @@ absl::Status EmbeddingLookupText::Initialize() {
#endif

LITERT_ASSIGN_OR_RETURN(compiled_model_, litert::CompiledModel::Create(
env_, model_.Get(), options));
LITERT_ASSIGN_OR_RETURN(auto signatures, model_.GetSignatures());
env_, model.Get(), options));
LITERT_ASSIGN_OR_RETURN(auto signatures, model.GetSignatures());

if (signature_key_.has_value()) {
bool found = false;
Expand Down Expand Up @@ -354,4 +351,12 @@ absl::Status EmbeddingLookupText::Initialize() {
return absl::OkStatus();
}

// Forwards the acceleration query to the compiled embedding model.
//
// Returns a runtime-failure error when `compiled_model_` holds no value;
// otherwise returns whatever the compiled model's IsFullyAccelerated()
// reports.
litert::Expected<bool> EmbeddingLookupText::IsFullyAccelerated() {
  if (compiled_model_.has_value()) {
    return compiled_model_->IsFullyAccelerated();
  }
  return litert::Error(litert::Status::kErrorRuntimeFailure,
                       "Compiled model has not been created.");
}

} // namespace litert::lm
12 changes: 5 additions & 7 deletions runtime/components/embedding_lookup/embedding_lookup_text.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/nullability.h" // from @com_google_absl
Expand All @@ -31,6 +30,7 @@
#include "absl/types/span.h" // from @com_google_absl
#include "litert/cc/litert_compiled_model.h" // from @litert
#include "litert/cc/litert_environment.h" // from @litert
#include "litert/cc/litert_expected.h" // from @litert
#include "litert/cc/litert_model.h" // from @litert
#include "litert/cc/litert_options.h" // from @litert
#include "litert/cc/litert_ranked_tensor_type.h" // from @litert
Expand Down Expand Up @@ -103,6 +103,8 @@ class EmbeddingLookupText : public EmbeddingLookup {
// Returns number of floats per token in the output tensor.
size_t GetFloatsPerToken();

// Returns whether the compiled embedding model is fully accelerated, or an
// error if the compiled model has not been created yet.
litert::Expected<bool> IsFullyAccelerated() override;

// Returns the default embedding vector to use when a token is not found in
// the lookup table.
const std::vector<float>& GetDefaultEmbeddingVector() const {
Expand All @@ -116,22 +118,18 @@ class EmbeddingLookupText : public EmbeddingLookup {

protected:
EmbeddingLookupText(litert::Environment& env,
const litert::Model* absl_nonnull model,
std::optional<std::string> signature_key)
: env_(env), model_(*model), signature_key_(signature_key) {}
: env_(env), signature_key_(signature_key) {}

// Loads the provided model. This must be called before Lookup.
absl::Status Initialize();
absl::Status Initialize(const litert::Model& model);

// Internal implementation of Lookup for both the single and multiple token
// cases.
absl::Status LookupInternal(int token, absl::Span<uint8_t> buffer);

// The environment for the embedding lookup.
litert::Environment& env_;
// The model for the embedding lookup. The actual model instance is owned by
// the model resources.
const litert::Model& model_;
// The compiled model for the embedding model.
std::optional<litert::CompiledModel> compiled_model_;

Expand Down
5 changes: 5 additions & 0 deletions runtime/components/model_resources.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ class ModelResources {

// Returns the llm metadata.
virtual absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() = 0;

// Releases the TFLite model from RAM. This is used to reduce peak memory
// usage after the model has been compiled into a hardware-specific
// executable.
virtual absl::Status ReleaseTFLiteModel(ModelType model_type) = 0;
};

} // namespace litert::lm
Expand Down
10 changes: 10 additions & 0 deletions runtime/components/model_resources_litert_lm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,14 @@ ModelResourcesLitertLm::GetWeightsSectionOffset(ModelType model_type) {
BufferKey(schema::AnySectionDataType_TFLiteWeights, model_type));
}

// Releases what this object holds for the TFLite model of `model_type`:
//   1. erases the `model_type` entry from model_map_,
//   2. asks the loader to release the TFLiteModel section,
//   3. asks the loader to release the TFLiteWeights section.
// A failing ReleaseSection() call is propagated immediately via
// RETURN_IF_ERROR; at that point the model_map_ entry is already gone and any
// later step has not run. Returns OkStatus when both releases succeed.
absl::Status ModelResourcesLitertLm::ReleaseTFLiteModel(ModelType model_type) {
// NOTE(review): the map entry is erased before the sections are released,
// presumably so the litert::Model no longer refers to the section buffers
// when they go away -- confirm.
model_map_.erase(model_type);
RETURN_IF_ERROR(litert_lm_loader_->ReleaseSection(
BufferKey(schema::AnySectionDataType_TFLiteModel, model_type)));
RETURN_IF_ERROR(litert_lm_loader_->ReleaseSection(
BufferKey(schema::AnySectionDataType_TFLiteWeights, model_type)));

return absl::OkStatus();
}

} // namespace litert::lm
3 changes: 3 additions & 0 deletions runtime/components/model_resources_litert_lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <utility>

#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "litert/cc/litert_model.h" // from @litert
Expand Down Expand Up @@ -64,6 +65,8 @@ class ModelResourcesLitertLm : public ModelResources {
absl::StatusOr<std::pair<size_t, size_t>> GetWeightsSectionOffset(
ModelType model_type) override;

// Erases the cached model for `model_type` and asks the loader to release its
// TFLiteModel and TFLiteWeights sections. Returns the first release error.
absl::Status ReleaseTFLiteModel(ModelType model_type) override;

protected:
explicit ModelResourcesLitertLm(
std::unique_ptr<LitertLmLoader> litert_lm_loader)
Expand Down
4 changes: 4 additions & 0 deletions runtime/components/model_resources_streaming.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@ ModelResourcesStreaming::GetLlmMetadata() {
return absl::UnimplementedError("Not implemented.");
}

// Releasing a TFLite model is not implemented for streaming resources:
// `model_type` is ignored and an Unimplemented error is always returned.
absl::Status ModelResourcesStreaming::ReleaseTFLiteModel(ModelType model_type) {
return absl::UnimplementedError("Not implemented.");
}

} // namespace litert::lm
3 changes: 3 additions & 0 deletions runtime/components/model_resources_streaming.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <string>
#include <utility>

#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "litert/cc/litert_model.h" // from @litert
Expand Down Expand Up @@ -60,6 +61,8 @@ class ModelResourcesStreaming : public ModelResources {
absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() override;

absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() override;

// Not implemented for streaming resources; always returns an Unimplemented
// error.
absl::Status ReleaseTFLiteModel(ModelType model_type) override;
};

} // namespace litert::lm
Expand Down
5 changes: 5 additions & 0 deletions runtime/components/model_resources_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ absl::StatusOr<const litert::Model*> ModelResourcesTask::GetTFLiteModel(
return model_map_[model_type].get();
}

// Erases the model_map_ entry for `model_type`. The result of erase() is not
// inspected, so OkStatus is returned unconditionally.
// NOTE(review): this presumably destroys the litert::Model owned by the map,
// so pointers previously handed out by GetTFLiteModel() for this model type
// would dangle -- confirm against callers.
absl::Status ModelResourcesTask::ReleaseTFLiteModel(ModelType model_type) {
model_map_.erase(model_type);
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<Tokenizer>> ModelResourcesTask::GetTokenizer() {
ASSIGN_OR_RETURN(auto string_view,
model_asset_bundle_resources_->GetFile("TOKENIZER_MODEL"));
Expand Down
Loading
Loading