|
14 | 14 |
|
15 | 15 | #include "runtime/conversation/model_data_processor/fastvlm_data_processor.h" |
16 | 16 |
|
| 17 | +#include <deque> |
17 | 18 | #include <memory> |
18 | | -#include <optional> |
19 | 19 | #include <string> |
20 | 20 | #include <utility> |
| 21 | +#include <variant> |
21 | 22 | #include <vector> |
22 | 23 |
|
23 | 24 | #include "absl/memory/memory.h" // from @com_google_absl |
24 | 25 | #include "absl/status/status.h" // from @com_google_absl |
25 | 26 | #include "absl/status/statusor.h" // from @com_google_absl |
| 27 | +#include "absl/strings/string_view.h" // from @com_google_absl |
26 | 28 | #include "nlohmann/json.hpp" // from @nlohmann_json |
27 | | -#include "runtime/components/tokenizer.h" |
| 29 | +#include "litert/cc/litert_layout.h" // from @litert |
| 30 | +#include "runtime/components/preprocessor/image_preprocessor.h" |
| 31 | +#include "runtime/components/preprocessor/stb_image_preprocessor.h" |
| 32 | +#include "runtime/components/prompt_template.h" |
28 | 33 | #include "runtime/conversation/io_types.h" |
| 34 | +#include "runtime/conversation/model_data_processor/data_utils.h" |
29 | 35 | #include "runtime/conversation/model_data_processor/fastvlm_data_processor_config.h" |
30 | | -#include "runtime/conversation/model_data_processor/gemma3_data_processor.h" |
31 | | -#include "runtime/conversation/model_data_processor/gemma3_data_processor_config.h" |
32 | 36 | #include "runtime/conversation/model_data_processor/model_data_processor.h" |
33 | 37 | #include "runtime/engine/io_types.h" |
| 38 | +#include "runtime/util/memory_mapped_file.h" |
34 | 39 | #include "runtime/util/status_macros.h" |
| 40 | +#include "re2/re2.h" // from @com_googlesource_code_re2 |
35 | 41 |
|
36 | 42 | namespace litert::lm { |
37 | 43 |
|
| 44 | +namespace { |
| 45 | + |
| 46 | +using ::nlohmann::ordered_json; |
| 47 | + |
| 48 | +bool IsImage(absl::string_view part) { return part == "<image_soft_token>"; } |
| 49 | + |
| 50 | +} // namespace |
| 51 | + |
38 | 52 | absl::StatusOr<std::unique_ptr<FastVlmDataProcessor>> |
39 | | -FastVlmDataProcessor::Create( |
40 | | - FastVlmDataProcessorConfig config, std::optional<Preface> preface, |
41 | | - const Tokenizer* tokenizer, |
42 | | - const std::vector<std::vector<int>>& stop_token_ids, |
43 | | - bool enable_constrained_decoding) { |
44 | | - Gemma3DataProcessorConfig gemma3_config; |
45 | | - gemma3_config.boi_token = config.boi_token; |
46 | | - gemma3_config.eoi_token = config.eoi_token; |
47 | | - gemma3_config.image_tensor_height = config.image_tensor_height; |
48 | | - gemma3_config.image_tensor_width = config.image_tensor_width; |
49 | | - |
50 | | - ASSIGN_OR_RETURN(auto impl, Gemma3DataProcessor::Create( |
51 | | - gemma3_config, preface, tokenizer, |
52 | | - stop_token_ids, enable_constrained_decoding)); |
53 | | - return absl::WrapUnique(new FastVlmDataProcessor(config, std::move(impl))); |
| 53 | +FastVlmDataProcessor::Create(FastVlmDataProcessorConfig config, |
| 54 | + const PromptTemplateCapabilities& capabilities) { |
| 55 | + return absl::WrapUnique(new FastVlmDataProcessor( |
| 56 | + config, capabilities, std::make_unique<StbImagePreprocessor>())); |
| 57 | +} |
| 58 | + |
| 59 | +absl::StatusOr<ordered_json> FastVlmDataProcessor::MessageToTemplateInput( |
| 60 | + const ordered_json& message) const { |
| 61 | + if (message["content"].is_string() && capabilities_.requires_typed_content) { |
| 62 | + return ordered_json::object( |
| 63 | + {{"role", message["role"]}, |
| 64 | + {"content", ordered_json::array( |
| 65 | + {{{"type", "text"}, {"text", message["content"]}}})}}); |
| 66 | + } else if (message["content"].is_array() && message["content"].size() == 1 && |
| 67 | + message["content"][0]["type"] == "text" && |
| 68 | + !capabilities_.requires_typed_content) { |
| 69 | + return ordered_json::object({{"role", message["role"]}, |
| 70 | + {"content", message["content"][0]["text"]}}); |
| 71 | + } else { |
| 72 | + return message; |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +absl::StatusOr<ordered_json> FastVlmDataProcessor::FormatTools( |
| 77 | + const ordered_json& tools) const { |
| 78 | + return absl::UnimplementedError("FastVLM does not support tool calling."); |
54 | 79 | } |
55 | 80 |
|
56 | 81 | absl::StatusOr<std::vector<InputData>> |
57 | 82 | FastVlmDataProcessor::ToInputDataVectorImpl( |
58 | | - const std::string& rendered_template_prompt, |
59 | | - const nlohmann::ordered_json& messages, |
| 83 | + const std::string& rendered_template_prompt, const ordered_json& messages, |
60 | 84 | const FastVlmDataProcessorArguments& args) const { |
61 | | - return impl_->ToInputDataVector(rendered_template_prompt, messages, |
62 | | - Gemma3DataProcessorArguments{}); |
| 85 | + std::vector<InputData> input_data; |
| 86 | + std::deque<std::unique_ptr<MemoryMappedFile>> image_files; |
| 87 | + |
| 88 | + for (const auto& message : messages) { |
| 89 | + if (message.contains("content") && message["content"].is_array()) { |
| 90 | + for (const auto& item : message["content"]) { |
| 91 | + if (item.is_string()) { |
| 92 | + continue; |
| 93 | + } |
| 94 | + ASSIGN_OR_RETURN(std::unique_ptr<MemoryMappedFile> mmap_file, |
| 95 | + LoadItemData(item)); |
| 96 | + if (item["type"] == "image") { |
| 97 | + image_files.push_back(std::move(mmap_file)); |
| 98 | + } |
| 99 | + } |
| 100 | + } |
| 101 | + } |
| 102 | + |
| 103 | + RE2 re_delimiter("(<image_soft_token>)"); |
| 104 | + absl::string_view prompt_view(rendered_template_prompt); |
| 105 | + const char* start = prompt_view.data(); |
| 106 | + std::string part; |
| 107 | + ImagePreprocessParameter image_params; |
| 108 | + image_params.SetTargetDimensions(Dimensions( |
| 109 | + {1, config_.image_tensor_height, config_.image_tensor_width, 3})); |
| 110 | + |
| 111 | + while (RE2::FindAndConsume(&prompt_view, re_delimiter, &part)) { |
| 112 | + absl::string_view text_part(start, prompt_view.data() - part.size()); |
| 113 | + start = prompt_view.data(); |
| 114 | + if (IsImage(part)) { |
| 115 | + input_data.emplace_back(InputText(std::string(text_part))); |
| 116 | + |
| 117 | + if (image_files.empty()) { |
| 118 | + return absl::InvalidArgumentError( |
| 119 | + "Provided less images than expected in the prompt."); |
| 120 | + } |
| 121 | + auto image_file = std::move(image_files.front()); |
| 122 | + image_files.pop_front(); |
| 123 | + ASSIGN_OR_RETURN(auto preprocessed_image, |
| 124 | + image_preprocessor_->Preprocess( |
| 125 | + InputImage(std::string( |
| 126 | + static_cast<const char*>(image_file->data()), |
| 127 | + image_file->length())), |
| 128 | + image_params)); |
| 129 | + input_data.emplace_back(InputImage(std::move(preprocessed_image))); |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + if (!image_files.empty()) { |
| 134 | + return absl::InvalidArgumentError( |
| 135 | + "Provided more images than expected in the prompt."); |
| 136 | + } |
| 137 | + |
| 138 | + if (!prompt_view.empty()) { |
| 139 | + input_data.push_back(InputText(std::string(prompt_view))); |
| 140 | + } |
| 141 | + |
| 142 | + return input_data; |
63 | 143 | } |
64 | 144 |
|
65 | 145 | absl::StatusOr<Message> FastVlmDataProcessor::ToMessageImpl( |
66 | 146 | const Responses& responses, |
67 | 147 | const FastVlmDataProcessorArguments& args) const { |
68 | | - return impl_->ToMessage(responses, Gemma3DataProcessorArguments{}); |
| 148 | + absl::string_view response_text = responses.GetTexts()[0]; |
| 149 | + ordered_json content = ordered_json::array( |
| 150 | + {{{"type", "text"}, {"text", std::string(response_text)}}}); |
| 151 | + return ordered_json::object({{"role", "assistant"}, {"content", content}}); |
69 | 152 | } |
70 | 153 |
|
71 | 154 | absl::Status FastVlmDataProcessor::CloneStateImpl( |
72 | 155 | const TypeSafeModelDataProcessor<FastVlmDataProcessorConfig, |
73 | 156 | FastVlmDataProcessorArguments>& other) { |
74 | | - const FastVlmDataProcessor& other_fastvlm = |
75 | | - static_cast<const FastVlmDataProcessor&>(other); |
76 | | - return impl_->CloneState(*other_fastvlm.impl_); |
| 157 | + return absl::OkStatus(); |
77 | 158 | } |
78 | 159 |
|
79 | 160 | } // namespace litert::lm |
0 commit comments