Skip to content

Commit 2831aff

Browse files
hheydarycopybara-github
authored andcommitted
Internal change
LiteRT-LM-PiperOrigin-RevId: 908791856
1 parent 00f8059 commit 2831aff

17 files changed

Lines changed: 763 additions & 299 deletions

runtime/conversation/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ cc_test(
181181
"//runtime/components/constrained_decoding:bitmap",
182182
"//runtime/components/constrained_decoding:constraint",
183183
"//runtime/components/constrained_decoding:external_constraint_config",
184-
"//runtime/core:session_factory", # buildcleaner: keep
185184
"//runtime/engine:engine_factory",
186185
"//runtime/engine:engine_interface",
187186
"//runtime/engine:engine_settings",

runtime/core/BUILD

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ package_group(
4343
)
4444

4545
ENGINE_IMPL_COMMON_DEPS = [
46-
":session_factory",
4746
"@com_google_absl//absl/base:no_destructor",
4847
"@com_google_absl//absl/log",
4948
"@com_google_absl//absl/log:absl_check",
@@ -97,6 +96,7 @@ cc_library(
9796
visibility = [":engine_impl_users"],
9897
deps = ENGINE_IMPL_COMMON_DEPS + [
9998
"//runtime/components:default_static_gpu_samplers",
99+
"//runtime/core:session_basic",
100100
"//runtime/executor:default_static_gpu_accelerator",
101101
] + select({
102102
"//conditions:default": [],
@@ -107,7 +107,9 @@ cc_library(
107107
cc_library(
108108
name = "engine_impl_cpu_only",
109109
srcs = ["engine_impl.cc"],
110-
deps = ENGINE_IMPL_COMMON_DEPS,
110+
deps = ENGINE_IMPL_COMMON_DEPS + [
111+
"//runtime/core:session_basic",
112+
],
111113
alwayslink = 1,
112114
)
113115

@@ -307,44 +309,6 @@ cc_test(
307309
}),
308310
)
309311

310-
cc_library(
311-
name = "session_factory",
312-
srcs = ["session_factory.cc"],
313-
hdrs = ["session_factory.h"],
314-
deps = [
315-
":session_basic",
316-
"@com_google_absl//absl/base:nullability",
317-
"@com_google_absl//absl/status:statusor",
318-
"//runtime/components:tokenizer",
319-
"//runtime/engine:engine_interface",
320-
"//runtime/engine:engine_settings",
321-
"//runtime/engine:io_types",
322-
"//runtime/executor:audio_executor",
323-
"//runtime/executor:llm_executor",
324-
"//runtime/executor:vision_executor",
325-
"//runtime/framework:threadpool",
326-
"//runtime/proto:sampler_params_cc_proto",
327-
"//runtime/util:litert_status_util",
328-
],
329-
)
330-
331-
cc_test(
332-
name = "session_factory_test",
333-
srcs = ["session_factory_test.cc"],
334-
deps = [
335-
":session_factory",
336-
"@com_google_googletest//:gtest_main",
337-
"@com_google_absl//absl/status:statusor",
338-
"@com_google_absl//absl/strings:string_view",
339-
"//runtime/components:tokenizer",
340-
"//runtime/engine:engine_settings",
341-
"//runtime/executor:executor_settings_base",
342-
"//runtime/executor:fake_llm_executor",
343-
"//runtime/framework:threadpool",
344-
"//runtime/util:test_utils",
345-
],
346-
)
347-
348312
cc_library(
349313
name = "session_utils",
350314
srcs = ["session_utils.cc"],

runtime/core/engine_advanced_impl.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include "litert/cc/litert_macros.h" // from @litert
3737
#include "runtime/components/model_resources.h"
3838
#include "runtime/components/tokenizer.h"
39-
#include "runtime/core/session_factory.h"
39+
#include "runtime/core/session_advanced.h"
4040
#include "runtime/engine/engine.h"
4141
#include "runtime/engine/engine_factory.h"
4242
#include "runtime/engine/engine_settings.h"
@@ -184,8 +184,8 @@ class EngineAdvancedImpl : public Engine {
184184

185185
ASSIGN_OR_RETURN(
186186
auto session,
187-
InitializeSessionAdvanced(execution_manager_, tokenizer_.get(), config,
188-
std::move(session_benchmark_info)));
187+
SessionAdvanced::Create(execution_manager_, tokenizer_.get(), config,
188+
std::move(session_benchmark_info)));
189189

190190
if (benchmark_info_.has_value()) {
191191
auto session_benchmark_info_or = session->GetMutableBenchmarkInfo();

runtime/core/engine_impl.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include "litert/cc/litert_macros.h" // from @litert
3737
#include "runtime/components/model_resources.h"
3838
#include "runtime/components/tokenizer.h"
39-
#include "runtime/core/session_factory.h"
39+
#include "runtime/core/session_basic.h"
4040
#include "runtime/engine/engine.h"
4141
#include "runtime/engine/engine_factory.h"
4242
#include "runtime/engine/engine_settings.h"
@@ -187,11 +187,11 @@ class EngineImpl : public Engine {
187187
}
188188
ASSIGN_OR_RETURN(
189189
auto session,
190-
InitializeSessionBasic(executor_.get(), tokenizer_.get(),
191-
/*vision_executor=*/vision_executor_.get(),
192-
/*audio_executor=*/audio_executor_.get(), config,
193-
std::move(session_benchmark_info),
194-
worker_thread_pool_.get()));
190+
SessionBasic::Create(executor_.get(), tokenizer_.get(),
191+
/*vision_executor=*/vision_executor_.get(),
192+
/*audio_executor=*/audio_executor_.get(), config,
193+
std::move(session_benchmark_info),
194+
worker_thread_pool_.get()));
195195
if (benchmark_info_.has_value()) {
196196
auto session_benchmark_info_or = session->GetMutableBenchmarkInfo();
197197
if (session_benchmark_info_or.ok()) {

runtime/core/session_factory.cc

Lines changed: 0 additions & 47 deletions
This file was deleted.

runtime/core/session_factory.h

Lines changed: 0 additions & 49 deletions
This file was deleted.

runtime/core/session_factory_test.cc

Lines changed: 0 additions & 80 deletions
This file was deleted.

runtime/executor/fake_llm_executor.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_MOCK_LLM_EXECUTOR_H_
1616
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_EXECUTOR_MOCK_LLM_EXECUTOR_H_
1717

18+
#include <cstdint>
1819
#include <memory>
1920
#include <optional>
21+
#include <utility>
2022
#include <vector>
2123

2224
#include "absl/status/status.h" // from @com_google_absl
@@ -85,6 +87,23 @@ class FakeLlmExecutor : public LlmExecutor {
8587
};
8688
absl::StatusOr<int> GetCurrentStep() const override { return current_step_; }
8789

90+
absl::StatusOr<std::unique_ptr<LlmContext>> CreateNewContext(
91+
std::optional<uint32_t> lora_id,
92+
RuntimeConfig runtime_config) const override {
93+
return std::make_unique<LlmContext>(
94+
nullptr, std::make_unique<RuntimeConfig>(std::move(runtime_config)),
95+
std::make_unique<RuntimeState>());
96+
};
97+
98+
absl::Status RestoreContext(
99+
std::unique_ptr<LlmContext> llm_context) override {
100+
return absl::OkStatus();
101+
};
102+
103+
absl::StatusOr<const ProcessedTokens*> GetProcessedTokens() const override {
104+
return &processed_tokens_;
105+
}
106+
88107
absl::Status SetCurrentStep(int current_step) override {
89108
current_step_ = current_step;
90109
if (current_step >= prefill_tokens_total_) {
@@ -95,6 +114,11 @@ class FakeLlmExecutor : public LlmExecutor {
95114
return absl::OkStatus();
96115
}
97116

117+
absl::Status UpdateRuntimeConfig(
118+
const RuntimeConfig& runtime_config) override {
119+
return absl::OkStatus();
120+
}
121+
98122
// Sets the status to be returned by the Prefill function.
99123
void SetPrefillStatus(const absl::Status& status) {
100124
prefill_status_ = status;

0 commit comments

Comments
 (0)