|
22 | 22 | #include <iterator> |
23 | 23 | #include <limits> |
24 | 24 | #include <memory> |
| 25 | +#include <optional> |
25 | 26 | #include <string> |
26 | 27 | #include <utility> |
27 | 28 | #include <variant> |
|
31 | 32 | #include "absl/log/absl_log.h" // from @com_google_absl |
32 | 33 | #include "absl/status/status.h" // from @com_google_absl |
33 | 34 | #include "absl/status/statusor.h" // from @com_google_absl |
| 35 | +#include "absl/strings/ascii.h" // from @com_google_absl |
34 | 36 | #include "absl/strings/match.h" // from @com_google_absl |
| 37 | +#include "absl/strings/numbers.h" // from @com_google_absl |
| 38 | +#include "absl/strings/str_split.h" // from @com_google_absl |
35 | 39 | #include "absl/strings/string_view.h" // from @com_google_absl |
36 | 40 | #include "absl/types/span.h" // from @com_google_absl |
37 | 41 | #include "litert/cc/litert_element_type.h" // from @litert |
@@ -122,8 +126,90 @@ BuildModelResourcesFromLitertLmFormat(const ModelAssets& model_assets) { |
122 | 126 | return ModelResourcesLitertLm::Create(std::move(loader)); |
123 | 127 | } |
124 | 128 |
|
| 129 | +absl::StatusOr<int> GetCacheTensorLayerIndex(absl::string_view tensor_name, |
| 130 | + absl::string_view root_name) { |
| 131 | + if (!absl::StartsWith(tensor_name, root_name)) { |
| 132 | + return absl::InvalidArgumentError("Tensor name does not match cache root."); |
| 133 | + } |
| 134 | + const absl::string_view suffix = tensor_name.substr(root_name.size()); |
| 135 | + int layer_index = 0; |
| 136 | + if (!absl::SimpleAtoi(suffix, &layer_index) || layer_index < 0) { |
| 137 | + return absl::InvalidArgumentError( |
| 138 | + "Failed to parse cache tensor layer index."); |
| 139 | + } |
| 140 | + return layer_index; |
| 141 | +} |
| 142 | + |
125 | 143 | } // namespace |
126 | 144 |
|
| 145 | +absl::StatusOr<CacheAdapterMetadata> ParseCacheAdapterMetadata( |
| 146 | + std::optional<std::string> cache_adapter_kind, |
| 147 | + std::optional<std::string> cache_layer_types) { |
| 148 | + CacheAdapterMetadata metadata; |
| 149 | + if (!cache_adapter_kind.has_value() || cache_adapter_kind->empty()) { |
| 150 | + return metadata; |
| 151 | + } |
| 152 | + metadata.cache_adapter_kind = |
| 153 | + absl::AsciiStrToLower(absl::StripAsciiWhitespace(*cache_adapter_kind)); |
| 154 | + if (metadata.cache_adapter_kind.empty()) { |
| 155 | + return metadata; |
| 156 | + } |
| 157 | + if (metadata.cache_adapter_kind == "qwen35") { |
| 158 | + RET_CHECK(cache_layer_types.has_value() && !cache_layer_types->empty()) |
| 159 | + .SetCode(absl::StatusCode::kFailedPrecondition) |
| 160 | + << "cache_layer_types metadata is required for qwen35 cache adapter."; |
| 161 | + for (absl::string_view part : absl::StrSplit(*cache_layer_types, ',')) { |
| 162 | + const std::string layer_type = |
| 163 | + std::string(absl::AsciiStrToLower(absl::StripAsciiWhitespace(part))); |
| 164 | + if (!layer_type.empty()) { |
| 165 | + metadata.cache_layer_types.push_back(layer_type); |
| 166 | + } |
| 167 | + } |
| 168 | + RET_CHECK(!metadata.cache_layer_types.empty()) |
| 169 | + .SetCode(absl::StatusCode::kFailedPrecondition) |
| 170 | + << "cache_layer_types metadata is empty for qwen35 cache adapter."; |
| 171 | + } |
| 172 | + return metadata; |
| 173 | +} |
| 174 | + |
| 175 | +absl::StatusOr<CacheAdapterMetadata> GetCacheAdapterMetadata( |
| 176 | + ModelResources& resources, ModelType model_type) { |
| 177 | + return ParseCacheAdapterMetadata( |
| 178 | + resources.GetTFLiteModelMetadataValue(model_type, "cache_adapter_kind"), |
| 179 | + resources.GetTFLiteModelMetadataValue(model_type, "cache_layer_types")); |
| 180 | +} |
| 181 | + |
| 182 | +absl::StatusOr<bool> IsSequenceCacheTensorName( |
| 183 | + absl::string_view tensor_name, absl::string_view k_root_name, |
| 184 | + absl::string_view v_root_name, const CacheAdapterMetadata& metadata) { |
| 185 | + if (!metadata.has_adapter()) { |
| 186 | + return true; |
| 187 | + } |
| 188 | + if (metadata.cache_adapter_kind != "qwen35") { |
| 189 | + return true; |
| 190 | + } |
| 191 | + const bool is_key_tensor = absl::StartsWith(tensor_name, k_root_name); |
| 192 | + const bool is_value_tensor = absl::StartsWith(tensor_name, v_root_name); |
| 193 | + RET_CHECK(is_key_tensor || is_value_tensor) |
| 194 | + .SetCode(absl::StatusCode::kInvalidArgument) |
| 195 | + << "Tensor name is not a recognized KV cache tensor: " << tensor_name; |
| 196 | + const absl::string_view root_name = is_key_tensor ? k_root_name : v_root_name; |
| 197 | + ASSIGN_OR_RETURN(int layer_index, |
| 198 | + GetCacheTensorLayerIndex(tensor_name, root_name)); |
| 199 | + RET_CHECK_LT(layer_index, metadata.cache_layer_types.size()) |
| 200 | + .SetCode(absl::StatusCode::kFailedPrecondition) |
| 201 | + << "Cache metadata layer_types does not cover tensor: " << tensor_name; |
| 202 | + const absl::string_view layer_type = metadata.cache_layer_types[layer_index]; |
| 203 | + if (layer_type == "full_attention") { |
| 204 | + return true; |
| 205 | + } |
| 206 | + if (layer_type == "linear_attention") { |
| 207 | + return false; |
| 208 | + } |
| 209 | + return absl::FailedPreconditionError( |
| 210 | + "Unsupported qwen35 cache layer type in metadata."); |
| 211 | +} |
| 212 | + |
127 | 213 | absl::StatusOr<ModelSignatures> GetModelSignaturesFromInputOutputNames( |
128 | 214 | const std::vector<absl::string_view>& input_names, |
129 | 215 | const std::vector<absl::string_view>& output_names, bool strict) { |
|
0 commit comments