Skip to content

Commit 2008af0

Browse files
hheydary authored and copybara-github committed
Transitioning to SessionAdvanced.
LiteRT-LM-PiperOrigin-RevId: 897279613
1 parent 1bd571a commit 2008af0

15 files changed

Lines changed: 98 additions & 3567 deletions

File tree

c/BUILD

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ cc_library(
8181
],
8282
visibility = ["//visibility:public"],
8383
deps = ENGINE_COMMON_DEPS + [
84-
"//runtime/core:engine_impl_cpu_only",
84+
"//runtime/core:engine_advanced_impl_cpu_only",
8585
],
8686
)
8787

c/engine.cc

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,9 @@ absl::AnyInvocable<void(absl::StatusOr<litert::lm::Responses>)> CreateCallback(
6161
litert::lm::TaskState::kMaxNumTokensReached) {
6262
callback(callback_data, /*text=*/nullptr, /*is_final=*/true,
6363
"Max number of tokens reached.");
64+
} else if (responses->GetTaskState() == litert::lm::TaskState::kCancelled) {
65+
callback(callback_data, /*text=*/nullptr, /*is_final=*/true,
66+
"CANCELLED.");
6467
} else {
6568
for (const auto& text : responses->GetTexts()) {
6669
callback(callback_data, text.data(), /*is_final=*/false,

kotlin/java/com/google/ai/edge/litertlm/jni/litertlm.cc

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -750,7 +750,14 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeGenerateContentStream)(
750750
(jint)absl::StatusCode::kInternal, message);
751751
env->DeleteLocalRef(message);
752752
cleanup_callback_ref();
753-
} else {
753+
} else if (responses->GetTaskState() ==
754+
litert::lm::TaskState::kCancelled) {
755+
jstring message = NewStringStandardUTF(env, "Process cancelled.");
756+
env->CallVoidMethod(callback_global, on_error_mid, (jint)1,
757+
message);
758+
env->DeleteLocalRef(message);
759+
cleanup_callback_ref();
760+
} else if (!responses->GetTexts().empty()) {
754761
jstring response_jstr =
755762
NewStringStandardUTF(env, responses->GetTexts()[0]);
756763
env->CallVoidMethod(callback_global, on_response_mid,

runtime/conversation/conversation_test.cc

Lines changed: 14 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1905,7 +1905,6 @@ TEST_P(ConversationTest, GetBenchmarkInfo) {
19051905
ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault(
19061906
model_assets, Backend::CPU));
19071907
engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache");
1908-
engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(15);
19091908
proto::BenchmarkParams benchmark_params;
19101909
engine_settings.GetMutableBenchmarkParams() = benchmark_params;
19111910
ASSERT_OK_AND_ASSIGN(auto engine, EngineFactory::CreateAny(engine_settings));
@@ -1920,17 +1919,21 @@ TEST_P(ConversationTest, GetBenchmarkInfo) {
19201919
.Build(*engine));
19211920
ASSERT_OK_AND_ASSIGN(auto conversation,
19221921
Conversation::Create(*engine, config));
1923-
ASSERT_OK_AND_ASSIGN(const Message message_1,
1924-
conversation->SendMessage(Message{
1925-
{"role", "user"}, {"content", "Hello world!"}}));
1922+
ASSERT_OK_AND_ASSIGN(
1923+
const Message message_1,
1924+
conversation->SendMessage(
1925+
Message{{"role", "user"}, {"content", "Hello world!"}},
1926+
{.max_output_tokens = 8}));
19261927
ASSERT_OK_AND_ASSIGN(const BenchmarkInfo benchmark_info_1,
19271928
conversation->GetBenchmarkInfo());
19281929
EXPECT_EQ(benchmark_info_1.GetTotalPrefillTurns(),
19291930
prefill_preface_on_init_ ? 2 : 1);
19301931

1931-
ASSERT_OK_AND_ASSIGN(const Message message_2,
1932-
conversation->SendMessage(Message{
1933-
{"role", "user"}, {"content", "Hello world!"}}));
1932+
ASSERT_OK_AND_ASSIGN(
1933+
const Message message_2,
1934+
conversation->SendMessage(
1935+
Message{{"role", "user"}, {"content", "Hello world!"}},
1936+
{.max_output_tokens = 8}));
19341937
ASSERT_OK_AND_ASSIGN(const BenchmarkInfo benchmark_info_2,
19351938
conversation->GetBenchmarkInfo());
19361939
EXPECT_EQ(benchmark_info_2.GetTotalPrefillTurns(),
@@ -2194,9 +2197,6 @@ TEST_P(ConversationCancellationTest, CancelProcessWithBenchmarkInfo) {
21942197
ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault(
21952198
model_assets, Backend::CPU));
21962199
engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache");
2197-
// Set a large max num tokens to ensure the decoding is not finished before
2198-
// cancellation.
2199-
engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(20);
22002200
if (use_benchmark_info) {
22012201
proto::BenchmarkParams benchmark_params;
22022202
engine_settings.GetMutableBenchmarkParams() = benchmark_params;
@@ -2208,10 +2208,10 @@ TEST_P(ConversationCancellationTest, CancelProcessWithBenchmarkInfo) {
22082208

22092209
absl::Status status;
22102210
absl::Notification done_1;
2211-
conversation
2212-
->SendMessageAsync(Message{{"role", "user"}, {"content", "Hello world!"}},
2213-
CreateCancelledMessageCallback(status, done_1))
2214-
.IgnoreError();
2211+
ASSERT_OK(conversation->SendMessageAsync(
2212+
Message{{"role", "user"}, {"content", "Hello world!"}},
2213+
CreateCancelledMessageCallback(status, done_1),
2214+
{.max_output_tokens = 128}));
22152215
// Wait for a short time to ensure the decoding has started.
22162216
absl::SleepFor(absl::Milliseconds(100));
22172217
conversation->CancelProcess();
@@ -2221,32 +2221,6 @@ TEST_P(ConversationCancellationTest, CancelProcessWithBenchmarkInfo) {
22212221

22222222
// The history should be empty after cancellation.
22232223
EXPECT_THAT(conversation->GetHistory().size(), 0);
2224-
2225-
// Resend the message after cancellation, and it should succeed.
2226-
status = absl::OkStatus();
2227-
absl::Notification done_2;
2228-
conversation
2229-
->SendMessageAsync(Message{{"role", "user"}, {"content", "Hello world!"}},
2230-
CreateCancelledMessageCallback(status, done_2))
2231-
.IgnoreError();
2232-
EXPECT_OK(status);
2233-
// Wait for the callback to be done.
2234-
done_2.WaitForNotificationWithTimeout(absl::Seconds(10));
2235-
// Without cancellation, the history should have two messages, user and
2236-
// assistant.
2237-
auto history = conversation->GetHistory();
2238-
ASSERT_EQ(history.size(), 2);
2239-
EXPECT_THAT(history[0], testing::Eq(Message{{"role", "user"},
2240-
{"content", "Hello world!"}}));
2241-
// TODO(b/450903294) - Because the cancellation is not fully rollbacked, the
2242-
// assistant message content depends on at which step the cancellation is
2243-
// triggered, and that is non-deterministic. Here we only check the role is
2244-
// assistant.
2245-
EXPECT_EQ(history[1]["role"], "assistant");
2246-
2247-
conversation->CancelProcess();
2248-
// No op after cancellation again.
2249-
EXPECT_THAT(conversation->GetHistory().size(), 2);
22502224
}
22512225

22522226
INSTANTIATE_TEST_SUITE_P(ConversationCancellationTest,

runtime/core/BUILD

Lines changed: 52 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ package_group(
4343
)
4444

4545
ENGINE_IMPL_COMMON_DEPS = [
46+
":session_advanced",
4647
"@com_google_absl//absl/base:no_destructor",
4748
"@com_google_absl//absl/log",
4849
"@com_google_absl//absl/log:absl_check",
@@ -91,41 +92,51 @@ ENGINE_IMPL_COMMON_DEPS = [
9192
})
9293

9394
cc_library(
94-
name = "engine_impl",
95-
srcs = ["engine_impl.cc"],
96-
visibility = [":engine_impl_users"],
95+
name = "engine_advanced_impl",
96+
srcs = ["engine_advanced_impl.cc"],
97+
local_defines = select({
98+
"//conditions:default": [],
99+
}),
97100
deps = ENGINE_IMPL_COMMON_DEPS + [
98101
"//runtime/components:default_static_gpu_samplers",
99-
"//runtime/core:session_basic",
100102
"//runtime/executor:default_static_gpu_accelerator",
103+
"//runtime/executor:vision_executor_settings",
104+
"//runtime/executor:vision_executor_utils",
105+
"//runtime/framework/resource_management:execution_manager",
106+
"//runtime/framework/resource_management:serial_execution_manager",
107+
"//runtime/framework/resource_management:threaded_execution_manager",
101108
] + select({
102109
"//conditions:default": [],
103110
}),
104111
alwayslink = 1,
105112
)
106113

107114
cc_library(
108-
name = "engine_impl_cpu_only",
109-
srcs = ["engine_impl.cc"],
115+
name = "engine_advanced_impl_cpu_only",
116+
srcs = ["engine_advanced_impl.cc"],
110117
deps = ENGINE_IMPL_COMMON_DEPS + [
111-
"//runtime/core:session_basic",
112-
],
118+
"//runtime/executor:vision_executor_settings",
119+
"//runtime/executor:vision_executor_utils",
120+
"//runtime/framework/resource_management:execution_manager",
121+
"//runtime/framework/resource_management:serial_execution_manager",
122+
"//runtime/framework/resource_management:threaded_execution_manager",
123+
] + select({
124+
"//conditions:default": [],
125+
}),
113126
alwayslink = 1,
114127
)
115128

116129
cc_test(
117-
name = "engine_impl_test",
118-
srcs = ["engine_impl_test.cc"],
119-
# The LiteRT GPU path is not ready yet. Only test the CPU path.
120-
args = ["--gunit_filter=-EngineTest.CreateEngineGPU*"],
130+
name = "engine_advanced_impl_test",
131+
srcs = ["engine_advanced_impl_test.cc"],
121132
data = ["//runtime/testdata"],
122-
tags = ["requires-mac-inputs:hard"], # Required for running on Forge on Mac.
133+
defines = ["ENGINE_ADVANCED"],
134+
tags = ["requires-mac-inputs:hard"],
123135
deps = [
124-
":engine_impl", # buildcleaner: keep
136+
":engine_advanced_impl", # buildcleaner: keep
125137
"@com_google_googletest//:gtest_main",
126138
"@com_google_absl//absl/cleanup",
127139
"@com_google_absl//absl/log:absl_check",
128-
"@com_google_absl//absl/log:absl_log",
129140
"@com_google_absl//absl/status",
130141
"@com_google_absl//absl/status:statusor",
131142
"@com_google_absl//absl/strings",
@@ -136,11 +147,25 @@ cc_test(
136147
"//runtime/executor:executor_settings_base",
137148
"//runtime/executor:llm_executor_settings",
138149
"//runtime/proto:sampler_params_cc_proto",
150+
"//runtime/util:litert_status_util",
139151
"//runtime/util:scoped_file",
140152
"//runtime/util:test_utils",
141153
],
142154
)
143155

156+
# TODO - b/502275587: Remove these aliases once the migration is complete.
157+
alias(
158+
name = "engine_impl",
159+
actual = ":engine_advanced_impl",
160+
deprecation = "Use engine_advanced_impl instead.",
161+
)
162+
163+
alias(
164+
name = "engine_impl_cpu_only",
165+
actual = ":engine_advanced_impl_cpu_only",
166+
deprecation = "Use engine_advanced_impl_cpu_only instead.",
167+
)
168+
144169
cc_library(
145170
name = "pipeline",
146171
srcs = ["pipeline.cc"],
@@ -205,17 +230,17 @@ cc_test(
205230
)
206231

207232
cc_library(
208-
name = "session_basic",
209-
srcs = ["session_basic.cc"],
210-
hdrs = ["session_basic.h"],
233+
name = "session_advanced",
234+
srcs = ["session_advanced.cc"],
235+
hdrs = ["session_advanced.h"],
211236
deps = [
212-
":pipeline",
213237
":session_utils",
214238
"@com_google_absl//absl/base:core_headers",
215239
"@com_google_absl//absl/base:nullability",
216240
"@com_google_absl//absl/container:flat_hash_map",
217241
"@com_google_absl//absl/container:flat_hash_set",
218242
"@com_google_absl//absl/functional:any_invocable",
243+
"@com_google_absl//absl/log",
219244
"@com_google_absl//absl/log:absl_log",
220245
"@com_google_absl//absl/memory",
221246
"@com_google_absl//absl/status",
@@ -224,52 +249,27 @@ cc_library(
224249
"@com_google_absl//absl/strings:string_view",
225250
"@com_google_absl//absl/synchronization",
226251
"@com_google_absl//absl/time",
227-
"@com_google_absl//absl/types:span",
228-
"@litert//litert/cc:litert_layout",
229-
"@litert//litert/cc:litert_macros",
230-
"@litert//litert/cc:litert_tensor_buffer_types",
231-
"//runtime/components:sampler",
232-
"//runtime/components:sampler_factory",
233-
"//runtime/components:stop_token_detector",
234252
"//runtime/components:tokenizer",
235-
"//runtime/components/constrained_decoding:constraint",
236253
"//runtime/engine:engine_interface",
237254
"//runtime/engine:engine_settings",
238255
"//runtime/engine:io_types",
239-
"//runtime/executor:audio_executor",
240-
"//runtime/executor:executor_settings_base",
241-
"//runtime/executor:llm_executor",
242256
"//runtime/executor:llm_executor_io_types",
243-
"//runtime/executor:vision_executor",
244-
"//runtime/framework:threadpool",
245-
"//runtime/proto:llm_model_type_cc_proto",
257+
"//runtime/framework/resource_management:execution_manager",
246258
"//runtime/proto:sampler_params_cc_proto",
247-
"//runtime/util:convert_tensor_buffer",
248-
"//runtime/util:executor_data_util",
249259
"//runtime/util:litert_status_util",
250-
"//runtime/util:model_type_utils",
251-
"//runtime/util:tensor_buffer_util",
252-
] + select({
253-
"@litert//litert:litert_link_capi_so": [
254-
"@litert//litert/cc:litert_api_with_dynamic_runtime",
255-
],
256-
"//conditions:default": [
257-
"@litert//litert/cc:litert_model",
258-
"@litert//litert/cc:litert_tensor_buffer",
259-
],
260-
}),
260+
],
261261
)
262262

263263
cc_test(
264-
name = "session_basic_test",
265-
srcs = ["session_basic_test.cc"],
264+
name = "session_advanced_test",
265+
srcs = ["session_advanced_test.cc"],
266266
data = [
267267
"//runtime/components/testdata",
268268
"//runtime/testdata",
269269
],
270270
tags = ["requires-mac-inputs:hard"], # Required for running on Forge on Mac.
271271
deps = [
272-
":session_basic",
272+
":session_advanced",
273273
"@com_google_googletest//:gtest_main",
274274
"@com_google_absl//absl/container:flat_hash_map",
275275
"@com_google_absl//absl/functional:any_invocable",
@@ -278,10 +278,10 @@ cc_test(
278278
"@com_google_absl//absl/status:statusor",
279279
"@com_google_absl//absl/strings",
280280
"@com_google_absl//absl/strings:string_view",
281-
"@com_google_absl//absl/synchronization",
282281
"@com_google_absl//absl/time",
283282
"@litert//litert/cc:litert_tensor_buffer",
284283
"@litert//litert/test:matchers",
284+
"//runtime/components:model_resources",
285285
"//runtime/components:sentencepiece_tokenizer",
286286
"//runtime/components:tokenizer",
287287
"//runtime/components/constrained_decoding:fake_constraint",
@@ -293,6 +293,8 @@ cc_test(
293293
"//runtime/executor:fake_llm_executor",
294294
"//runtime/executor:llm_executor_io_types",
295295
"//runtime/framework:threadpool",
296+
"//runtime/framework/resource_management:execution_manager",
297+
"//runtime/framework/resource_management:threaded_execution_manager",
296298
"//runtime/util:convert_tensor_buffer",
297299
"//runtime/util:litert_status_util",
298300
"//runtime/util:scoped_file",

0 commit comments

Comments
 (0)