Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = ENGINE_COMMON_DEPS + [
"//runtime/core:engine_impl_cpu_only",
"//runtime/core:engine_advanced_impl_cpu_only",
],
)

Expand Down
3 changes: 3 additions & 0 deletions c/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ absl::AnyInvocable<void(absl::StatusOr<litert::lm::Responses>)> CreateCallback(
litert::lm::TaskState::kMaxNumTokensReached) {
callback(callback_data, /*text=*/nullptr, /*is_final=*/true,
"Max number of tokens reached.");
} else if (responses->GetTaskState() == litert::lm::TaskState::kCancelled) {
callback(callback_data, /*text=*/nullptr, /*is_final=*/true,
"CANCELLED.");
} else {
for (const auto& text : responses->GetTexts()) {
callback(callback_data, text.data(), /*is_final=*/false,
Expand Down
9 changes: 8 additions & 1 deletion kotlin/java/com/google/ai/edge/litertlm/jni/litertlm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,14 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeGenerateContentStream)(
(jint)absl::StatusCode::kInternal, message);
env->DeleteLocalRef(message);
cleanup_callback_ref();
} else {
} else if (responses->GetTaskState() ==
litert::lm::TaskState::kCancelled) {
jstring message = NewStringStandardUTF(env, "Process cancelled.");
env->CallVoidMethod(callback_global, on_error_mid, (jint)1,
message);
env->DeleteLocalRef(message);
cleanup_callback_ref();
} else if (!responses->GetTexts().empty()) {
jstring response_jstr =
NewStringStandardUTF(env, responses->GetTexts()[0]);
env->CallVoidMethod(callback_global, on_response_mid,
Expand Down
54 changes: 14 additions & 40 deletions runtime/conversation/conversation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1915,7 +1915,6 @@ TEST_P(ConversationTest, GetBenchmarkInfo) {
ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault(
model_assets, Backend::CPU));
engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache");
engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(15);
proto::BenchmarkParams benchmark_params;
engine_settings.GetMutableBenchmarkParams() = benchmark_params;
ASSERT_OK_AND_ASSIGN(auto engine,
Expand All @@ -1931,17 +1930,21 @@ TEST_P(ConversationTest, GetBenchmarkInfo) {
.Build(*engine));
ASSERT_OK_AND_ASSIGN(auto conversation,
Conversation::Create(*engine, config));
ASSERT_OK_AND_ASSIGN(const Message message_1,
conversation->SendMessage(Message{
{"role", "user"}, {"content", "Hello world!"}}));
ASSERT_OK_AND_ASSIGN(
const Message message_1,
conversation->SendMessage(
Message{{"role", "user"}, {"content", "Hello world!"}},
{.max_output_tokens = 8}));
ASSERT_OK_AND_ASSIGN(const BenchmarkInfo benchmark_info_1,
conversation->GetBenchmarkInfo());
EXPECT_EQ(benchmark_info_1.GetTotalPrefillTurns(),
prefill_preface_on_init_ ? 2 : 1);

ASSERT_OK_AND_ASSIGN(const Message message_2,
conversation->SendMessage(Message{
{"role", "user"}, {"content", "Hello world!"}}));
ASSERT_OK_AND_ASSIGN(
const Message message_2,
conversation->SendMessage(
Message{{"role", "user"}, {"content", "Hello world!"}},
{.max_output_tokens = 8}));
ASSERT_OK_AND_ASSIGN(const BenchmarkInfo benchmark_info_2,
conversation->GetBenchmarkInfo());
EXPECT_EQ(benchmark_info_2.GetTotalPrefillTurns(),
Expand Down Expand Up @@ -2206,9 +2209,6 @@ TEST_P(ConversationCancellationTest, CancelProcessWithBenchmarkInfo) {
ASSERT_OK_AND_ASSIGN(auto engine_settings, EngineSettings::CreateDefault(
model_assets, Backend::CPU));
engine_settings.GetMutableMainExecutorSettings().SetCacheDir(":nocache");
// Set a large max num tokens to ensure the decoding is not finished before
// cancellation.
engine_settings.GetMutableMainExecutorSettings().SetMaxNumTokens(20);
if (use_benchmark_info) {
proto::BenchmarkParams benchmark_params;
engine_settings.GetMutableBenchmarkParams() = benchmark_params;
Expand All @@ -2221,10 +2221,10 @@ TEST_P(ConversationCancellationTest, CancelProcessWithBenchmarkInfo) {

absl::Status status;
absl::Notification done_1;
conversation
->SendMessageAsync(Message{{"role", "user"}, {"content", "Hello world!"}},
CreateCancelledMessageCallback(status, done_1))
.IgnoreError();
ASSERT_OK(conversation->SendMessageAsync(
Message{{"role", "user"}, {"content", "Hello world!"}},
CreateCancelledMessageCallback(status, done_1),
{.max_output_tokens = 128}));
// Wait for a short time to ensure the decoding has started.
absl::SleepFor(absl::Milliseconds(100));
conversation->CancelProcess();
Expand All @@ -2234,32 +2234,6 @@ TEST_P(ConversationCancellationTest, CancelProcessWithBenchmarkInfo) {

// The history should be empty after cancellation.
EXPECT_THAT(conversation->GetHistory().size(), 0);

// Resend the message after cancellation, and it should succeed.
status = absl::OkStatus();
absl::Notification done_2;
conversation
->SendMessageAsync(Message{{"role", "user"}, {"content", "Hello world!"}},
CreateCancelledMessageCallback(status, done_2))
.IgnoreError();
EXPECT_OK(status);
// Wait for the callback to be done.
done_2.WaitForNotificationWithTimeout(absl::Seconds(10));
// Without cancellation, the history should have two messages, user and
// assistant.
auto history = conversation->GetHistory();
ASSERT_EQ(history.size(), 2);
EXPECT_THAT(history[0], testing::Eq(Message{{"role", "user"},
{"content", "Hello world!"}}));
// TODO(b/450903294) - Because the cancellation is not fully rollbacked, the
// assistant message content depends on at which step the cancellation is
// triggered, and that is non-deterministic. Here we only check the role is
// assistant.
EXPECT_EQ(history[1]["role"], "assistant");

conversation->CancelProcess();
// No op after cancellation again.
EXPECT_THAT(conversation->GetHistory().size(), 2);
}

INSTANTIATE_TEST_SUITE_P(ConversationCancellationTest,
Expand Down
102 changes: 52 additions & 50 deletions runtime/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ package_group(
)

ENGINE_IMPL_COMMON_DEPS = [
":session_advanced",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:absl_check",
Expand Down Expand Up @@ -91,41 +92,51 @@ ENGINE_IMPL_COMMON_DEPS = [
})

cc_library(
name = "engine_impl",
srcs = ["engine_impl.cc"],
visibility = [":engine_impl_users"],
name = "engine_advanced_impl",
srcs = ["engine_advanced_impl.cc"],
local_defines = select({
"//conditions:default": [],
}),
deps = ENGINE_IMPL_COMMON_DEPS + [
"//runtime/components:default_static_gpu_samplers",
"//runtime/core:session_basic",
"//runtime/executor:default_static_gpu_accelerator",
"//runtime/executor:vision_executor_settings",
"//runtime/executor:vision_executor_utils",
"//runtime/framework/resource_management:execution_manager",
"//runtime/framework/resource_management:serial_execution_manager",
"//runtime/framework/resource_management:threaded_execution_manager",
] + select({
"//conditions:default": [],
}),
alwayslink = 1,
)

cc_library(
name = "engine_impl_cpu_only",
srcs = ["engine_impl.cc"],
name = "engine_advanced_impl_cpu_only",
srcs = ["engine_advanced_impl.cc"],
deps = ENGINE_IMPL_COMMON_DEPS + [
"//runtime/core:session_basic",
],
"//runtime/executor:vision_executor_settings",
"//runtime/executor:vision_executor_utils",
"//runtime/framework/resource_management:execution_manager",
"//runtime/framework/resource_management:serial_execution_manager",
"//runtime/framework/resource_management:threaded_execution_manager",
] + select({
"//conditions:default": [],
}),
alwayslink = 1,
)

cc_test(
name = "engine_impl_test",
srcs = ["engine_impl_test.cc"],
# The LiteRT GPU path is not ready yet. Only test the CPU path.
args = ["--gunit_filter=-EngineTest.CreateEngineGPU*"],
name = "engine_advanced_impl_test",
srcs = ["engine_advanced_impl_test.cc"],
data = ["//runtime/testdata"],
tags = ["requires-mac-inputs:hard"], # Required for running on Forge on Mac.
defines = ["ENGINE_ADVANCED"],
tags = ["requires-mac-inputs:hard"],
deps = [
":engine_impl", # buildcleaner: keep
":engine_advanced_impl", # buildcleaner: keep
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand All @@ -136,11 +147,25 @@ cc_test(
"//runtime/executor:executor_settings_base",
"//runtime/executor:llm_executor_settings",
"//runtime/proto:sampler_params_cc_proto",
"//runtime/util:litert_status_util",
"//runtime/util:scoped_file",
"//runtime/util:test_utils",
],
)

# TODO - b/502275587: Remove these aliases once the migration is complete.
alias(
name = "engine_impl",
actual = ":engine_advanced_impl",
deprecation = "Use engine_advanced_impl instead.",
)

alias(
name = "engine_impl_cpu_only",
actual = ":engine_advanced_impl_cpu_only",
deprecation = "Use engine_advanced_impl_cpu_only instead.",
)

cc_library(
name = "pipeline",
srcs = ["pipeline.cc"],
Expand Down Expand Up @@ -205,17 +230,17 @@ cc_test(
)

cc_library(
name = "session_basic",
srcs = ["session_basic.cc"],
hdrs = ["session_basic.h"],
name = "session_advanced",
srcs = ["session_advanced.cc"],
hdrs = ["session_advanced.h"],
deps = [
":pipeline",
":session_utils",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
Expand All @@ -224,52 +249,27 @@ cc_library(
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@litert//litert/cc:litert_layout",
"@litert//litert/cc:litert_macros",
"@litert//litert/cc:litert_tensor_buffer_types",
"//runtime/components:sampler",
"//runtime/components:sampler_factory",
"//runtime/components:stop_token_detector",
"//runtime/components:tokenizer",
"//runtime/components/constrained_decoding:constraint",
"//runtime/engine:engine_interface",
"//runtime/engine:engine_settings",
"//runtime/engine:io_types",
"//runtime/executor:audio_executor",
"//runtime/executor:executor_settings_base",
"//runtime/executor:llm_executor",
"//runtime/executor:llm_executor_io_types",
"//runtime/executor:vision_executor",
"//runtime/framework:threadpool",
"//runtime/proto:llm_model_type_cc_proto",
"//runtime/framework/resource_management:execution_manager",
"//runtime/proto:sampler_params_cc_proto",
"//runtime/util:convert_tensor_buffer",
"//runtime/util:executor_data_util",
"//runtime/util:litert_status_util",
"//runtime/util:model_type_utils",
"//runtime/util:tensor_buffer_util",
] + select({
"@litert//litert:litert_link_capi_so": [
"@litert//litert/cc:litert_api_with_dynamic_runtime",
],
"//conditions:default": [
"@litert//litert/cc:litert_model",
"@litert//litert/cc:litert_tensor_buffer",
],
}),
],
)

cc_test(
name = "session_basic_test",
srcs = ["session_basic_test.cc"],
name = "session_advanced_test",
srcs = ["session_advanced_test.cc"],
data = [
"//runtime/components/testdata",
"//runtime/testdata",
],
tags = ["requires-mac-inputs:hard"], # Required for running on Forge on Mac.
deps = [
":session_basic",
":session_advanced",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:any_invocable",
Expand All @@ -278,10 +278,10 @@ cc_test(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@litert//litert/cc:litert_tensor_buffer",
"@litert//litert/test:matchers",
"//runtime/components:model_resources",
"//runtime/components:sentencepiece_tokenizer",
"//runtime/components:tokenizer",
"//runtime/components/constrained_decoding:fake_constraint",
Expand All @@ -293,6 +293,8 @@ cc_test(
"//runtime/executor:fake_llm_executor",
"//runtime/executor:llm_executor_io_types",
"//runtime/framework:threadpool",
"//runtime/framework/resource_management:execution_manager",
"//runtime/framework/resource_management:threaded_execution_manager",
"//runtime/util:convert_tensor_buffer",
"//runtime/util:litert_status_util",
"//runtime/util:scoped_file",
Expand Down
Loading
Loading