Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions runtime/components/model_resources.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ class ModelResources {
virtual std::optional<std::string> GetTFLiteModelBackendConstraint(
ModelType model_type) = 0;

// Returns a string-valued metadata entry attached to the TFLite model
// section. When the metadata key is not present, returns nullopt.
virtual std::optional<std::string> GetTFLiteModelMetadataValue(
ModelType model_type, absl::string_view key) = 0;

// Builds a tokenizer instance from the model and returns it.
virtual absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() = 0;

Expand Down
5 changes: 5 additions & 0 deletions runtime/components/model_resources_litert_lm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ ModelResourcesLitertLm::GetTFLiteModelBackendConstraint(ModelType model_type) {
return litert_lm_loader_->GetTFLiteModelBackendConstraint(model_type);
}

// Looks up a string metadata entry on the TFLite model section by delegating
// to the underlying *.litertlm loader; nullopt when the key is absent.
std::optional<std::string> ModelResourcesLitertLm::GetTFLiteModelMetadataValue(
    ModelType model_type, absl::string_view key) {
  return litert_lm_loader_->GetTFLiteModelMetadataValue(model_type, key);
}

absl::StatusOr<absl::string_view> ModelResourcesLitertLm::GetTFLiteModelBuffer(
ModelType model_type) {
litert::BufferRef<uint8_t> buffer_ref =
Expand Down
3 changes: 3 additions & 0 deletions runtime/components/model_resources_litert_lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class ModelResourcesLitertLm : public ModelResources {
std::optional<std::string> GetTFLiteModelBackendConstraint(
ModelType model_type) override;

std::optional<std::string> GetTFLiteModelMetadataValue(
ModelType model_type, absl::string_view key) override;

// Returns the tokenizer from the *.litertlm file. If both SentencePiece and
// HuggingFace tokenizer are present and supported by the current build
// configuration, the SentencePiece tokenizer will be used.
Expand Down
7 changes: 7 additions & 0 deletions runtime/components/model_resources_streaming.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ ModelResourcesStreaming::GetTFLiteModelBackendConstraint(ModelType model_type) {
return std::nullopt;
}

// Streaming resources expose no TFLite model metadata; every lookup reports
// "key not present". Parameters are intentionally unnamed (idiomatic C++ for
// deliberately-unused arguments, replacing the C-style (void) casts).
std::optional<std::string> ModelResourcesStreaming::GetTFLiteModelMetadataValue(
    ModelType /*model_type*/, absl::string_view /*key*/) {
  return std::nullopt;
}

absl::StatusOr<std::unique_ptr<Tokenizer>>
ModelResourcesStreaming::GetTokenizer() {
return absl::UnimplementedError("Not implemented.");
Expand Down
3 changes: 3 additions & 0 deletions runtime/components/model_resources_streaming.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class ModelResourcesStreaming : public ModelResources {
std::optional<std::string> GetTFLiteModelBackendConstraint(
ModelType model_type) override;

std::optional<std::string> GetTFLiteModelMetadataValue(
ModelType model_type, absl::string_view key) override;

absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() override;

absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() override;
Expand Down
6 changes: 6 additions & 0 deletions runtime/components/model_resources_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class ModelResourcesTask : public ModelResources {
// Task model does not support backend constraint.
return std::nullopt;
};
// Task model does not expose TFLite model metadata; always returns nullopt.
// Unnamed parameters mark them deliberately unused (replaces (void) casts).
std::optional<std::string> GetTFLiteModelMetadataValue(
    ModelType /*model_type*/, absl::string_view /*key*/) override {
  return std::nullopt;
}
absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() override;
absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() override;
absl::StatusOr<std::reference_wrapper<ScopedFile>> GetScopedFile() override {
Expand Down
86 changes: 86 additions & 0 deletions runtime/executor/litert_compiled_model_executor_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <iterator>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
Expand All @@ -31,7 +32,10 @@
#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/ascii.h" // from @com_google_absl
#include "absl/strings/match.h" // from @com_google_absl
#include "absl/strings/numbers.h" // from @com_google_absl
#include "absl/strings/str_split.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/cc/litert_element_type.h" // from @litert
Expand Down Expand Up @@ -122,8 +126,90 @@ BuildModelResourcesFromLitertLmFormat(const ModelAssets& model_assets) {
return ModelResourcesLitertLm::Create(std::move(loader));
}

// Extracts the numeric layer suffix from a KV cache tensor name given its
// root prefix (e.g. "kv_cache_k_12" with root "kv_cache_k_" -> 12).
// Fails if the name does not start with `root_name` or the remainder is not
// a non-negative decimal integer.
absl::StatusOr<int> GetCacheTensorLayerIndex(absl::string_view tensor_name,
                                             absl::string_view root_name) {
  if (!absl::StartsWith(tensor_name, root_name)) {
    return absl::InvalidArgumentError("Tensor name does not match cache root.");
  }
  int parsed_index = 0;
  const bool suffix_is_number =
      absl::SimpleAtoi(tensor_name.substr(root_name.size()), &parsed_index);
  if (!suffix_is_number || parsed_index < 0) {
    return absl::InvalidArgumentError(
        "Failed to parse cache tensor layer index.");
  }
  return parsed_index;
}

} // namespace

// Parses the raw cache adapter metadata strings from the TFLite model section.
//
// `cache_adapter_kind` selects the adapter; a missing or blank value yields a
// default (no-adapter) CacheAdapterMetadata. The kind is trimmed and
// lowercased before comparison. For "qwen35", `cache_layer_types` must hold a
// non-empty comma-separated list of per-layer cache types (also trimmed and
// lowercased); otherwise FAILED_PRECONDITION is returned.
absl::StatusOr<CacheAdapterMetadata> ParseCacheAdapterMetadata(
    std::optional<std::string> cache_adapter_kind,
    std::optional<std::string> cache_layer_types) {
  CacheAdapterMetadata metadata;
  if (!cache_adapter_kind.has_value() || cache_adapter_kind->empty()) {
    return metadata;
  }
  // Normalize so metadata written with different casing/whitespace matches.
  metadata.cache_adapter_kind =
      absl::AsciiStrToLower(absl::StripAsciiWhitespace(*cache_adapter_kind));
  if (metadata.cache_adapter_kind.empty()) {
    return metadata;
  }
  if (metadata.cache_adapter_kind == "qwen35") {
    RET_CHECK(cache_layer_types.has_value() && !cache_layer_types->empty())
        .SetCode(absl::StatusCode::kFailedPrecondition)
        << "cache_layer_types metadata is required for qwen35 cache adapter.";
    for (absl::string_view part : absl::StrSplit(*cache_layer_types, ',')) {
      // AsciiStrToLower already returns std::string; the previous extra
      // std::string(...) wrap forced a redundant copy. Move into the vector.
      std::string layer_type =
          absl::AsciiStrToLower(absl::StripAsciiWhitespace(part));
      if (!layer_type.empty()) {
        metadata.cache_layer_types.push_back(std::move(layer_type));
      }
    }
    RET_CHECK(!metadata.cache_layer_types.empty())
        .SetCode(absl::StatusCode::kFailedPrecondition)
        << "cache_layer_types metadata is empty for qwen35 cache adapter.";
  }
  return metadata;
}

// Fetches the two cache-adapter metadata entries for `model_type` from
// `resources` and parses them into a CacheAdapterMetadata.
absl::StatusOr<CacheAdapterMetadata> GetCacheAdapterMetadata(
    ModelResources& resources, ModelType model_type) {
  std::optional<std::string> adapter_kind =
      resources.GetTFLiteModelMetadataValue(model_type, "cache_adapter_kind");
  std::optional<std::string> layer_types =
      resources.GetTFLiteModelMetadataValue(model_type, "cache_layer_types");
  return ParseCacheAdapterMetadata(std::move(adapter_kind),
                                   std::move(layer_types));
}

// Determines whether `tensor_name` denotes a sequence KV cache tensor under
// the given cache adapter metadata.
//
// Without an adapter — or for adapter kinds other than "qwen35" — every cache
// tensor is treated as a sequence cache tensor. For "qwen35", the tensor's
// layer index (parsed from its key/value root prefix) selects an entry in
// `metadata.cache_layer_types`:
//   "full_attention"   -> sequence cache (true)
//   "linear_attention" -> not a sequence cache (false)
// Unrecognized tensor names, out-of-range indices, and unknown layer types
// produce errors.
absl::StatusOr<bool> IsSequenceCacheTensorName(
    absl::string_view tensor_name, absl::string_view k_root_name,
    absl::string_view v_root_name, const CacheAdapterMetadata& metadata) {
  if (!metadata.has_adapter()) {
    return true;
  }
  if (metadata.cache_adapter_kind != "qwen35") {
    return true;
  }
  const bool is_key_tensor = absl::StartsWith(tensor_name, k_root_name);
  const bool is_value_tensor = absl::StartsWith(tensor_name, v_root_name);
  RET_CHECK(is_key_tensor || is_value_tensor)
      .SetCode(absl::StatusCode::kInvalidArgument)
      << "Tensor name is not a recognized KV cache tensor: " << tensor_name;
  const absl::string_view root_name = is_key_tensor ? k_root_name : v_root_name;
  ASSIGN_OR_RETURN(int layer_index,
                   GetCacheTensorLayerIndex(tensor_name, root_name));
  // layer_index is guaranteed non-negative by GetCacheTensorLayerIndex, so
  // the cast is safe; it avoids a signed/unsigned comparison warning.
  RET_CHECK_LT(static_cast<size_t>(layer_index),
               metadata.cache_layer_types.size())
      .SetCode(absl::StatusCode::kFailedPrecondition)
      << "Cache metadata layer_types does not cover tensor: " << tensor_name;
  const absl::string_view layer_type = metadata.cache_layer_types[layer_index];
  if (layer_type == "full_attention") {
    return true;
  }
  if (layer_type == "linear_attention") {
    return false;
  }
  return absl::FailedPreconditionError(
      "Unsupported qwen35 cache layer type in metadata.");
}

absl::StatusOr<ModelSignatures> GetModelSignaturesFromInputOutputNames(
const std::vector<absl::string_view>& input_names,
const std::vector<absl::string_view>& output_names, bool strict) {
Expand Down
19 changes: 19 additions & 0 deletions runtime/executor/litert_compiled_model_executor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,25 @@ struct ModelSignatures {
std::string output_logits;
};

// Cache adapter configuration parsed from TFLite model metadata. Describes
// how per-layer KV cache tensors should be interpreted for models whose
// layers mix cache kinds (currently the "qwen35" adapter with
// full_attention/linear_attention layers).
struct CacheAdapterMetadata {
  // Normalized (trimmed, lowercased) adapter kind; empty when the model
  // declares no cache adapter.
  std::string cache_adapter_kind;
  // Per-layer cache type names, indexed by layer. Populated only for adapter
  // kinds that require it (qwen35).
  std::vector<std::string> cache_layer_types;

  // True when the model declared a cache adapter kind.
  bool has_adapter() const { return !cache_adapter_kind.empty(); }
};

// Parses the raw "cache_adapter_kind" / "cache_layer_types" metadata strings
// into a CacheAdapterMetadata. A missing or blank kind yields a default
// (no-adapter) result; a qwen35 kind without layer types is an error.
absl::StatusOr<CacheAdapterMetadata> ParseCacheAdapterMetadata(
    std::optional<std::string> cache_adapter_kind,
    std::optional<std::string> cache_layer_types);

// Reads the cache adapter metadata entries for `model_type` from `resources`
// and parses them via ParseCacheAdapterMetadata.
absl::StatusOr<CacheAdapterMetadata> GetCacheAdapterMetadata(
    ModelResources& resources,
    ModelType model_type = ModelType::kTfLitePrefillDecode);

// Returns whether `tensor_name` (a KV cache tensor starting with
// `k_root_name` or `v_root_name`) is a sequence cache tensor according to
// `metadata`. Always true when no recognized adapter is configured.
absl::StatusOr<bool> IsSequenceCacheTensorName(
    absl::string_view tensor_name, absl::string_view k_root_name,
    absl::string_view v_root_name, const CacheAdapterMetadata& metadata);

// Get the corresponding ModelSignatures struct for the given model using
// the signature runner. Returns an error if the runner's signature does not
// match any of the predefined signature set.
Expand Down
60 changes: 60 additions & 0 deletions runtime/executor/litert_compiled_model_executor_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <fstream>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <variant>
Expand Down Expand Up @@ -195,6 +196,65 @@ TEST(LlmLiteRTCompiledModelExecutorUtilsTest, GetKVCacheRootNames_KvCacheC) {
EXPECT_EQ(v_root_name, "kv_cache_c_");
}

TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
     ParseCacheAdapterMetadata_Qwen35) {
  // Whitespace around the comma-separated entries must be trimmed away.
  const auto adapter_kind = std::make_optional<std::string>("qwen35");
  const auto layer_types = std::make_optional<std::string>(
      "linear_attention, linear_attention, full_attention");
  ASSERT_OK_AND_ASSIGN(const auto parsed,
                       ParseCacheAdapterMetadata(adapter_kind, layer_types));
  EXPECT_EQ(parsed.cache_adapter_kind, "qwen35");
  EXPECT_THAT(
      parsed.cache_layer_types,
      ElementsAre("linear_attention", "linear_attention", "full_attention"));
}

TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
     IsSequenceCacheTensorName_DefaultsToSequenceCache) {
  // With no adapter configured, every KV cache tensor counts as a sequence
  // cache tensor.
  const CacheAdapterMetadata no_adapter;
  ASSERT_OK_AND_ASSIGN(
      const bool is_sequence,
      IsSequenceCacheTensorName("kv_cache_k_0", "kv_cache_k_", "kv_cache_v_",
                                no_adapter));
  EXPECT_TRUE(is_sequence);
}

TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
     IsSequenceCacheTensorName_Qwen35LinearAttention) {
  // Layer 1 is declared linear_attention, so neither its key nor its value
  // tensor is a sequence cache tensor.
  ASSERT_OK_AND_ASSIGN(
      const auto parsed,
      ParseCacheAdapterMetadata(
          std::make_optional<std::string>("qwen35"),
          std::make_optional<std::string>(
              "linear_attention,linear_attention,full_attention")));
  ASSERT_OK_AND_ASSIGN(const bool key_result,
                       IsSequenceCacheTensorName("kv_cache_k_1", "kv_cache_k_",
                                                 "kv_cache_v_", parsed));
  ASSERT_OK_AND_ASSIGN(const bool value_result,
                       IsSequenceCacheTensorName("kv_cache_v_1", "kv_cache_k_",
                                                 "kv_cache_v_", parsed));
  EXPECT_FALSE(key_result);
  EXPECT_FALSE(value_result);
}

TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
     IsSequenceCacheTensorName_Qwen35FullAttention) {
  // Layer 2 is declared full_attention, so both its key and value tensors are
  // sequence cache tensors.
  ASSERT_OK_AND_ASSIGN(
      const auto parsed,
      ParseCacheAdapterMetadata(
          std::make_optional<std::string>("qwen35"),
          std::make_optional<std::string>(
              "linear_attention,linear_attention,full_attention")));
  ASSERT_OK_AND_ASSIGN(const bool key_result,
                       IsSequenceCacheTensorName("kv_cache_k_2", "kv_cache_k_",
                                                 "kv_cache_v_", parsed));
  ASSERT_OK_AND_ASSIGN(const bool value_result,
                       IsSequenceCacheTensorName("kv_cache_v_2", "kv_cache_k_",
                                                 "kv_cache_v_", parsed));
  EXPECT_TRUE(key_result);
  EXPECT_TRUE(value_result);
}

TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
FillSingleBufferCacheParamTensor) {
LITERT_ASSERT_OK_AND_ASSIGN(auto env, ::litert::Environment::Create({}));
Expand Down
Loading
Loading