Skip to content

Commit 9dd9ac0

Browse files
copybara-github authored and committed
[WIP] Add linear attention cache type support to LiteRT-LM

PiperOrigin-RevId: 907204200
1 parent 09016c5 commit 9dd9ac0

20 files changed

Lines changed: 469 additions & 94 deletions

runtime/components/model_resources.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ class ModelResources {
174174
virtual std::optional<std::string> GetTFLiteModelBackendConstraint(
175175
ModelType model_type) = 0;
176176

177+
  // Returns a string-valued metadata entry attached to the TFLite model
  // section identified by `model_type`, looked up by `key`. When the
  // metadata key is not present, returns std::nullopt.
  virtual std::optional<std::string> GetTFLiteModelMetadataValue(
      ModelType model_type, absl::string_view key) = 0;
181+
177182
// Builds a tokenizer instance from the model and returns it.
178183
virtual absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() = 0;
179184

runtime/components/model_resources_litert_lm.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ ModelResourcesLitertLm::GetTFLiteModelBackendConstraint(ModelType model_type) {
8080
return litert_lm_loader_->GetTFLiteModelBackendConstraint(model_type);
8181
}
8282

83+
// Looks up a string-valued TFLite model metadata entry via the *.litertlm
// loader; returns std::nullopt when the loader has no entry for `key`.
std::optional<std::string> ModelResourcesLitertLm::GetTFLiteModelMetadataValue(
    ModelType model_type, absl::string_view key) {
  // The loader owns the model sections, so the lookup is delegated as-is.
  std::optional<std::string> metadata_value =
      litert_lm_loader_->GetTFLiteModelMetadataValue(model_type, key);
  return metadata_value;
}
87+
8388
absl::StatusOr<absl::string_view> ModelResourcesLitertLm::GetTFLiteModelBuffer(
8489
ModelType model_type) {
8590
litert::BufferRef<uint8_t> buffer_ref =

runtime/components/model_resources_litert_lm.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class ModelResourcesLitertLm : public ModelResources {
4949
std::optional<std::string> GetTFLiteModelBackendConstraint(
5050
ModelType model_type) override;
5151

52+
  // Returns the string metadata entry for `key` from the TFLite model
  // section, or std::nullopt when the key is absent.
  std::optional<std::string> GetTFLiteModelMetadataValue(
      ModelType model_type, absl::string_view key) override;
5255
// Returns the tokenizer from the *.litertlm file. If both SentencePiece and
5356
// HuggingFace tokenizer are present and supported by the current build
5457
// configuration, the SentencePiece tokenizer will be used.

runtime/components/model_resources_streaming.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ ModelResourcesStreaming::GetTFLiteModelBackendConstraint(ModelType model_type) {
5757
return std::nullopt;
5858
}
5959

60+
// Streaming model resources carry no TFLite metadata entries, so every
// lookup reports "not present".
std::optional<std::string> ModelResourcesStreaming::GetTFLiteModelMetadataValue(
    ModelType /*model_type*/, absl::string_view /*key*/) {
  // Unnamed parameters replace the `(void)` casts: the compiler then knows
  // they are intentionally unused without extra statements.
  return std::nullopt;
}
66+
6067
absl::StatusOr<std::unique_ptr<Tokenizer>>
6168
ModelResourcesStreaming::GetTokenizer() {
6269
return absl::UnimplementedError("Not implemented.");

runtime/components/model_resources_streaming.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ class ModelResourcesStreaming : public ModelResources {
5454
std::optional<std::string> GetTFLiteModelBackendConstraint(
5555
ModelType model_type) override;
5656

57+
  // Streaming resources expose no TFLite metadata; always returns
  // std::nullopt.
  std::optional<std::string> GetTFLiteModelMetadataValue(
      ModelType model_type, absl::string_view key) override;
5760
absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() override;
5861

5962
absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() override;

runtime/components/model_resources_task.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ class ModelResourcesTask : public ModelResources {
5151
// Task model does not support backend constraint.
5252
return std::nullopt;
5353
};
54+
std::optional<std::string> GetTFLiteModelMetadataValue(
55+
ModelType model_type, absl::string_view key) override {
56+
(void)model_type;
57+
(void)key;
58+
return std::nullopt;
59+
}
5460
absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() override;
5561
absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() override;
5662
absl::StatusOr<std::reference_wrapper<ScopedFile>> GetScopedFile() override {

runtime/executor/litert_compiled_model_executor_utils.cc

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <iterator>
2323
#include <limits>
2424
#include <memory>
25+
#include <optional>
2526
#include <string>
2627
#include <utility>
2728
#include <variant>
@@ -31,7 +32,10 @@
3132
#include "absl/log/absl_log.h" // from @com_google_absl
3233
#include "absl/status/status.h" // from @com_google_absl
3334
#include "absl/status/statusor.h" // from @com_google_absl
35+
#include "absl/strings/ascii.h" // from @com_google_absl
3436
#include "absl/strings/match.h" // from @com_google_absl
37+
#include "absl/strings/numbers.h" // from @com_google_absl
38+
#include "absl/strings/str_split.h" // from @com_google_absl
3539
#include "absl/strings/string_view.h" // from @com_google_absl
3640
#include "absl/types/span.h" // from @com_google_absl
3741
#include "litert/cc/litert_element_type.h" // from @litert
@@ -122,8 +126,90 @@ BuildModelResourcesFromLitertLmFormat(const ModelAssets& model_assets) {
122126
return ModelResourcesLitertLm::Create(std::move(loader));
123127
}
124128

129+
// Extracts the trailing per-layer index from a cache tensor name such as
// "kv_cache_k_3" given its root prefix ("kv_cache_k_"). Returns an
// InvalidArgument error when the name does not start with the root or when
// the suffix is not a non-negative integer.
absl::StatusOr<int> GetCacheTensorLayerIndex(absl::string_view tensor_name,
                                             absl::string_view root_name) {
  if (!absl::StartsWith(tensor_name, root_name)) {
    // Include the offending name so the failure is actionable in logs.
    return absl::InvalidArgumentError(
        "Tensor name does not match cache root: " + std::string(tensor_name));
  }
  const absl::string_view suffix = tensor_name.substr(root_name.size());
  int layer_index = 0;
  // SimpleAtoi rejects empty suffixes and trailing garbage, so this also
  // guards against names like "kv_cache_k_" or "kv_cache_k_1x".
  if (!absl::SimpleAtoi(suffix, &layer_index) || layer_index < 0) {
    return absl::InvalidArgumentError(
        "Failed to parse cache tensor layer index from: " +
        std::string(tensor_name));
  }
  return layer_index;
}
142+
125143
} // namespace
126144

145+
// Parses the cache adapter metadata attached to a TFLite model.
//
// `cache_adapter_kind` selects a cache adapter implementation (currently only
// "qwen35" has dedicated handling); `cache_layer_types` is a comma-separated
// per-layer list (e.g. "linear_attention,full_attention"). Both values are
// whitespace-trimmed and lower-cased. When no adapter kind is present (or it
// is whitespace-only) the returned metadata is empty (has_adapter() is
// false). For "qwen35", a non-empty layer-type list is required; otherwise a
// FailedPrecondition error is returned.
absl::StatusOr<CacheAdapterMetadata> ParseCacheAdapterMetadata(
    std::optional<std::string> cache_adapter_kind,
    std::optional<std::string> cache_layer_types) {
  CacheAdapterMetadata metadata;
  if (!cache_adapter_kind.has_value() || cache_adapter_kind->empty()) {
    return metadata;
  }
  metadata.cache_adapter_kind =
      absl::AsciiStrToLower(absl::StripAsciiWhitespace(*cache_adapter_kind));
  if (metadata.cache_adapter_kind.empty()) {
    // The metadata value was whitespace-only; treat it as absent.
    return metadata;
  }
  if (metadata.cache_adapter_kind == "qwen35") {
    RET_CHECK(cache_layer_types.has_value() && !cache_layer_types->empty())
        .SetCode(absl::StatusCode::kFailedPrecondition)
        << "cache_layer_types metadata is required for qwen35 cache adapter.";
    for (absl::string_view part : absl::StrSplit(*cache_layer_types, ',')) {
      // AsciiStrToLower already returns a std::string; move it into the
      // vector instead of binding it to a const local and copying.
      std::string layer_type =
          absl::AsciiStrToLower(absl::StripAsciiWhitespace(part));
      if (!layer_type.empty()) {
        metadata.cache_layer_types.push_back(std::move(layer_type));
      }
    }
    RET_CHECK(!metadata.cache_layer_types.empty())
        .SetCode(absl::StatusCode::kFailedPrecondition)
        << "cache_layer_types metadata is empty for qwen35 cache adapter.";
  }
  return metadata;
}
174+
175+
// Reads the cache adapter metadata entries for `model_type` from `resources`
// and parses them into a CacheAdapterMetadata.
absl::StatusOr<CacheAdapterMetadata> GetCacheAdapterMetadata(
    ModelResources& resources, ModelType model_type) {
  std::optional<std::string> adapter_kind =
      resources.GetTFLiteModelMetadataValue(model_type, "cache_adapter_kind");
  std::optional<std::string> layer_types =
      resources.GetTFLiteModelMetadataValue(model_type, "cache_layer_types");
  return ParseCacheAdapterMetadata(std::move(adapter_kind),
                                   std::move(layer_types));
}
181+
182+
// Decides whether a KV cache tensor behaves as a sequence (full-attention)
// cache. Without a recognized adapter every tensor is treated as a sequence
// cache. For the "qwen35" adapter the tensor's layer index is looked up in
// cache_layer_types: "full_attention" layers are sequence caches,
// "linear_attention" layers are not, and any other value is an error.
absl::StatusOr<bool> IsSequenceCacheTensorName(
    absl::string_view tensor_name, absl::string_view k_root_name,
    absl::string_view v_root_name, const CacheAdapterMetadata& metadata) {
  if (!metadata.has_adapter()) {
    return true;
  }
  if (metadata.cache_adapter_kind != "qwen35") {
    // Unknown adapter kinds fall back to the default sequence-cache behavior.
    return true;
  }
  const bool is_key_tensor = absl::StartsWith(tensor_name, k_root_name);
  const bool is_value_tensor = absl::StartsWith(tensor_name, v_root_name);
  RET_CHECK(is_key_tensor || is_value_tensor)
      .SetCode(absl::StatusCode::kInvalidArgument)
      << "Tensor name is not a recognized KV cache tensor: " << tensor_name;
  const absl::string_view root_name = is_key_tensor ? k_root_name : v_root_name;
  ASSIGN_OR_RETURN(int layer_index,
                   GetCacheTensorLayerIndex(tensor_name, root_name));
  // layer_index is guaranteed non-negative by GetCacheTensorLayerIndex; the
  // cast avoids a signed/unsigned comparison against the vector size.
  RET_CHECK_LT(static_cast<size_t>(layer_index),
               metadata.cache_layer_types.size())
      .SetCode(absl::StatusCode::kFailedPrecondition)
      << "Cache metadata layer_types does not cover tensor: " << tensor_name;
  const absl::string_view layer_type = metadata.cache_layer_types[layer_index];
  if (layer_type == "full_attention") {
    return true;
  }
  if (layer_type == "linear_attention") {
    return false;
  }
  // Surface the unexpected value so the model metadata can be fixed.
  return absl::FailedPreconditionError(
      "Unsupported qwen35 cache layer type in metadata: " +
      std::string(layer_type));
}
212+
127213
absl::StatusOr<ModelSignatures> GetModelSignaturesFromInputOutputNames(
128214
const std::vector<absl::string_view>& input_names,
129215
const std::vector<absl::string_view>& output_names, bool strict) {

runtime/executor/litert_compiled_model_executor_utils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,25 @@ struct ModelSignatures {
7474
std::string output_logits;
7575
};
7676

77+
// Cache adapter configuration parsed from TFLite model metadata. Describes
// how per-layer caches should be interpreted for models whose layers mix
// attention types (e.g. linear vs. full attention).
struct CacheAdapterMetadata {
  // Lower-cased adapter identifier (e.g. "qwen35"); empty when the model
  // declares no cache adapter.
  std::string cache_adapter_kind;
  // Per-layer cache type names (e.g. "linear_attention", "full_attention"),
  // indexed by the cache tensor's layer index.
  std::vector<std::string> cache_layer_types;

  // True when the model declared a cache adapter kind.
  bool has_adapter() const { return !cache_adapter_kind.empty(); }
};
83+
84+
// Parses the raw "cache_adapter_kind" and "cache_layer_types" metadata values
// into a CacheAdapterMetadata. Returns an empty metadata (has_adapter() ==
// false) when no adapter kind is present.
absl::StatusOr<CacheAdapterMetadata> ParseCacheAdapterMetadata(
    std::optional<std::string> cache_adapter_kind,
    std::optional<std::string> cache_layer_types);

// Reads the cache adapter metadata entries from `resources` for the given
// model type and parses them.
absl::StatusOr<CacheAdapterMetadata> GetCacheAdapterMetadata(
    ModelResources& resources,
    ModelType model_type = ModelType::kTfLitePrefillDecode);

// Returns true when the named KV cache tensor is a sequence (full-attention)
// cache, false when it is a linear-attention cache, and an error when the
// name or the adapter metadata is malformed.
absl::StatusOr<bool> IsSequenceCacheTensorName(
    absl::string_view tensor_name, absl::string_view k_root_name,
    absl::string_view v_root_name, const CacheAdapterMetadata& metadata);
95+
7796
// Get the corresponding ModelSignatures struct for the given model using
7897
// the signature runner. Returns an error if the runner's signature does not
7998
// match any of the predefined signature set.

runtime/executor/litert_compiled_model_executor_utils_test.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <fstream>
1919
#include <limits>
2020
#include <memory>
21+
#include <optional>
2122
#include <string>
2223
#include <utility>
2324
#include <variant>
@@ -195,6 +196,65 @@ TEST(LlmLiteRTCompiledModelExecutorUtilsTest, GetKVCacheRootNames_KvCacheC) {
195196
EXPECT_EQ(v_root_name, "kv_cache_c_");
196197
}
197198

199+
// The adapter kind and the comma-separated layer list are trimmed,
// lower-cased, and preserved in declaration order.
TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
     ParseCacheAdapterMetadata_Qwen35) {
  const std::optional<std::string> adapter_kind("qwen35");
  const std::optional<std::string> layer_types(
      "linear_attention, linear_attention, full_attention");
  ASSERT_OK_AND_ASSIGN(auto metadata,
                       ParseCacheAdapterMetadata(adapter_kind, layer_types));
  EXPECT_EQ(metadata.cache_adapter_kind, "qwen35");
  EXPECT_THAT(
      metadata.cache_layer_types,
      ElementsAre("linear_attention", "linear_attention", "full_attention"));
}
212+
213+
// Without any adapter metadata, every cache tensor is a sequence cache.
TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
     IsSequenceCacheTensorName_DefaultsToSequenceCache) {
  CacheAdapterMetadata empty_metadata;
  ASSERT_OK_AND_ASSIGN(
      const bool is_sequence,
      IsSequenceCacheTensorName("kv_cache_k_0", "kv_cache_k_", "kv_cache_v_",
                                empty_metadata));
  EXPECT_TRUE(is_sequence);
}
221+
222+
// Layer index 1 is declared "linear_attention", so both the key and the
// value cache tensors for that layer are not sequence caches.
TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
     IsSequenceCacheTensorName_Qwen35LinearAttention) {
  ASSERT_OK_AND_ASSIGN(
      const auto metadata,
      ParseCacheAdapterMetadata(
          std::make_optional<std::string>("qwen35"),
          std::make_optional<std::string>(
              "linear_attention,linear_attention,full_attention")));
  for (const absl::string_view tensor_name :
       {"kv_cache_k_1", "kv_cache_v_1"}) {
    ASSERT_OK_AND_ASSIGN(
        const bool is_sequence,
        IsSequenceCacheTensorName(tensor_name, "kv_cache_k_", "kv_cache_v_",
                                  metadata));
    EXPECT_FALSE(is_sequence) << tensor_name;
  }
}
239+
240+
// Layer index 2 is declared "full_attention", so both the key and the value
// cache tensors for that layer are sequence caches.
TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
     IsSequenceCacheTensorName_Qwen35FullAttention) {
  ASSERT_OK_AND_ASSIGN(
      const auto metadata,
      ParseCacheAdapterMetadata(
          std::make_optional<std::string>("qwen35"),
          std::make_optional<std::string>(
              "linear_attention,linear_attention,full_attention")));
  for (const absl::string_view tensor_name :
       {"kv_cache_k_2", "kv_cache_v_2"}) {
    ASSERT_OK_AND_ASSIGN(
        const bool is_sequence,
        IsSequenceCacheTensorName(tensor_name, "kv_cache_k_", "kv_cache_v_",
                                  metadata));
    EXPECT_TRUE(is_sequence) << tensor_name;
  }
}
257+
198258
TEST(LlmLiteRTCompiledModelExecutorUtilsTest,
199259
FillSingleBufferCacheParamTensor) {
200260
LITERT_ASSERT_OK_AND_ASSIGN(auto env, ::litert::Environment::Create({}));

0 commit comments

Comments
 (0)