45 changes: 36 additions & 9 deletions docs/api/kotlin/getting_started.md
@@ -166,15 +166,15 @@ engine.createConversation(conversationConfig).use { conversation ->

There are three ways to send messages:

- **`sendMessage(contents): Message`**: Synchronous call that blocks until the
model returns a complete response. This is simpler for basic
request/response interactions.
- **`sendMessageAsync(contents, callback)`**: Asynchronous call for streaming
responses. This is better for long-running requests or when you want to
display the response as it's being generated.
- **`sendMessageAsync(contents): Flow<Message>`**: Asynchronous call that
returns a Kotlin Flow for streaming responses. This is the recommended
approach for Coroutine users.
- **`sendMessage(contents, extraContext): Message`**: Synchronous call that
blocks until the model returns a complete response. This is simpler for
basic request/response interactions.
- **`sendMessageAsync(contents, callback, extraContext)`**: Asynchronous call
for streaming responses. This is better for long-running requests or when
you want to display the response as it's being generated.
- **`sendMessageAsync(contents, extraContext): Flow<Message>`**: Asynchronous
call that returns a Kotlin Flow for streaming responses. This is the
recommended approach for coroutine users; a short sketch follows this list.
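
Before the fuller examples below, here is a minimal sketch of the Flow-based
variant; the prompt text and `extraContext` values are illustrative, and each
streamed `Message` chunk is simply printed via its `toString()`:

```kotlin
import kotlinx.coroutines.runBlocking

// Minimal sketch: stream a reply from an existing Conversation using the
// Flow-returning overload. The prompt and context values are illustrative.
fun streamReply(conversation: Conversation) = runBlocking {
  conversation
    .sendMessageAsync("Hello!", extraContext = mapOf("user_name" to "Alice"))
    .collect { message ->
      // Each emitted Message is a streaming chunk of the model's response.
      println(message)
    }
}
```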

**Synchronous Example:**

@@ -456,6 +456,33 @@ To try out tool use, clone the repo and run with
bazel run -c opt //kotlin/java/com/google/ai/edge/litertlm/example:tool -- <abs_model_path>
```

### 7. Extra Template Context Variables

You can pass extra context variables to the prompt template for rendering.
This allows you to customize the model's behavior based on dynamic values.

`extraContext` is an optional `Map<String, Any>` that can be passed to
`sendMessage` and `sendMessageAsync`. These variables are merged with the extra
context provided in the `Preface` (if any), with keys in the message-level
context overwriting those in the `Preface`.

```kotlin
val extraContext = mapOf(
"user_name" to "Alice",
"enable_thinking" to true
)

// Synchronous
val response = conversation.sendMessage("Hello!", extraContext = extraContext)

// Asynchronous with Flow
conversation.sendMessageAsync("Hello!", extraContext = extraContext)
.collect { ... }
```

These variables are used within the Jinja-style prompt templates, e.g.,
`{{ user_name }}` or `{% if enable_thinking %}`.
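
For example, the same conversation can be steered per message by varying the
context values. A small sketch follows; the `enable_thinking` key is only
meaningful if the model's prompt template actually checks it:

```kotlin
// Illustrative sketch: toggle a template flag on a per-message basis.
val thoughtful = conversation.sendMessage(
  "Plan a weekend trip.",
  extraContext = mapOf("enable_thinking" to true),
)
val quick = conversation.sendMessage(
  "What's 2 + 2?",
  extraContext = mapOf("enable_thinking" to false),
)
```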

## Error Handling

API methods can throw `LiteRtLmJniException` for errors from the native layer or
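
A minimal sketch of guarding a blocking call, using the exception types named
in this guide and in the API's KDoc; the handling shown here (returning
`null`) is just one option:

```kotlin
// Illustrative sketch: catch the documented failure modes of sendMessage.
fun sendSafely(conversation: Conversation, text: String): Message? =
  try {
    conversation.sendMessage(text)
  } catch (e: LiteRtLmJniException) {
    null // Error surfaced from the native layer.
  } catch (e: IllegalStateException) {
    null // For example, the conversation was already closed.
  }
```
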
66 changes: 52 additions & 14 deletions kotlin/java/com/google/ai/edge/litertlm/Conversation.kt
@@ -86,18 +86,21 @@ class Conversation(
* [RECURRING_TOOL_CALL_LIMIT] times.
*
* @param message The message to send to the model.
* @param extraContext Optional context used for prompt template rendering.
* @return The model's response message.
* @throws IllegalStateException if the conversation is not alive, if the native layer returns an
* invalid response, or if the tool call limit is exceeded.
* @throws LiteRtLmJniException if an error occurs during the native call.
*/
fun sendMessage(message: Message): Message {
fun sendMessage(message: Message, extraContext: Map<String, Any> = emptyMap()): Message {
checkIsAlive()

var currentMessageJson = message.toJson()
var extraContextJsonString = extraContext.toJsonObject().toString()

for (i in 0..<RECURRING_TOOL_CALL_LIMIT) {
val responseJsonString = LiteRtLmJni.nativeSendMessage(handle, currentMessageJson.toString())
val responseJsonString =
LiteRtLmJni.nativeSendMessage(handle, currentMessageJson.toString(), extraContextJsonString)
val responseJsonObject = JsonParser.parseString(responseJsonString).asJsonObject

if (responseJsonObject.has("tool_calls")) {
@@ -124,13 +127,14 @@
* [RECURRING_TOOL_CALL_LIMIT] times.
*
* @param contents The list of contents to send to the model.
* @param extraContext Optional context used for prompt template rendering.
* @return The model's response message.
* @throws IllegalStateException if the conversation is not alive, if the native layer returns an
* invalid response, or if the tool call limit is exceeded.
* @throws LiteRtLmJniException if an error occurs during the native call.
*/
fun sendMessage(contents: Contents): Message {
return sendMessage(Message.user(contents))
fun sendMessage(contents: Contents, extraContext: Map<String, Any> = emptyMap()): Message {
return sendMessage(Message.user(contents), extraContext)
}

/**
@@ -142,12 +146,14 @@
* [RECURRING_TOOL_CALL_LIMIT] times.
*
* @param text The text to send to the model.
* @param extraContext Optional context used for prompt template rendering.
* @return The model's response message.
* @throws IllegalStateException if the conversation is not alive, if the native layer returns an
* invalid response, or if the tool call limit is exceeded.
* @throws LiteRtLmJniException if an error occurs during the native call.
*/
fun sendMessage(text: String): Message = sendMessage(Contents.of(text))
fun sendMessage(text: String, extraContext: Map<String, Any> = emptyMap()): Message =
sendMessage(Contents.of(text), extraContext)

/**
* Sends a message to the model and returns the response async with a callback.
@@ -159,14 +165,26 @@
*
* @param message The message to send to the model.
* @param callback The callback to receive the streaming responses.
* @param extraContext Optional context used for prompt template rendering.
* @throws IllegalStateException if the conversation has already been closed or the content is
* empty.
*/
fun sendMessageAsync(message: Message, callback: MessageCallback) {
fun sendMessageAsync(
message: Message,
callback: MessageCallback,
extraContext: Map<String, Any> = emptyMap(),
) {
checkIsAlive()

val extraContextJsonString = extraContext.toJsonObject().toString()

val jniCallback = JniMessageCallbackImpl(callback)
LiteRtLmJni.nativeSendMessageAsync(handle, message.toJson().toString(), jniCallback)
LiteRtLmJni.nativeSendMessageAsync(
handle,
message.toJson().toString(),
extraContextJsonString,
jniCallback,
)
}

/**
@@ -179,11 +197,15 @@
*
* @param contents The list of contents to send to the model.
* @param callback The callback to receive the streaming responses.
* @param extraContext Optional context used for prompt template rendering.
* @throws IllegalStateException if the conversation has already been closed or the content is
* empty.
*/
fun sendMessageAsync(contents: Contents, callback: MessageCallback) =
sendMessageAsync(Message.user(contents), callback)
fun sendMessageAsync(
contents: Contents,
callback: MessageCallback,
extraContext: Map<String, Any> = emptyMap(),
) = sendMessageAsync(Message.user(contents), callback, extraContext)

/**
* Sends a text to the model and returns the response async with a callback.
@@ -195,11 +217,15 @@
*
* @param text The text to send to the model.
* @param callback The callback to receive the streaming responses.
* @param extraContext Optional context used for prompt template rendering.
* @throws IllegalStateException if the conversation has already been closed or the content is
* empty.
*/
fun sendMessageAsync(text: String, callback: MessageCallback) =
sendMessageAsync(Contents.of(text), callback)
fun sendMessageAsync(
text: String,
callback: MessageCallback,
extraContext: Map<String, Any> = emptyMap(),
) = sendMessageAsync(Contents.of(text), callback, extraContext)

/**
* Sends a message to the model and returns the response async as a [Flow].
@@ -210,11 +236,15 @@
* [RECURRING_TOOL_CALL_LIMIT] times.
*
* @param message The message to send to the model.
* @param extraContext Optional context used for prompt template rendering.
* @return A Flow of messages representing the model's response.
* @throws IllegalStateException if the conversation has already been closed or the content is
* empty.
*/
fun sendMessageAsync(message: Message): Flow<Message> = callbackFlow {
fun sendMessageAsync(
message: Message,
extraContext: Map<String, Any> = emptyMap(),
): Flow<Message> = callbackFlow {
sendMessageAsync(
message,
object : MessageCallback {
@@ -230,6 +260,7 @@ class Conversation(
close(throwable)
}
},
extraContext,
)
awaitClose {}
}
@@ -243,11 +274,15 @@ class Conversation(
* [RECURRING_TOOL_CALL_LIMIT] times.
*
* @param contents The list of contents to send to the model.
* @param extraContext Optional context used for prompt template rendering.
* @return A Flow of messages representing the model's response.
* @throws IllegalStateException if the conversation has already been closed or the content is
* empty.
*/
fun sendMessageAsync(contents: Contents): Flow<Message> = sendMessageAsync(Message.user(contents))
fun sendMessageAsync(
contents: Contents,
extraContext: Map<String, Any> = emptyMap(),
): Flow<Message> = sendMessageAsync(Message.user(contents), extraContext)

/**
* Sends a text to the model and returns the response async as a [Flow].
@@ -258,11 +293,13 @@
* [RECURRING_TOOL_CALL_LIMIT] times.
*
* @param text The text to send to the model.
* @param extraContext Optional context used for prompt template rendering.
* @return A Flow of messages representing the model's response.
* @throws IllegalStateException if the conversation has already been closed or the content is
* empty.
*/
fun sendMessageAsync(text: String): Flow<Message> = sendMessageAsync(Contents.of(text))
fun sendMessageAsync(text: String, extraContext: Map<String, Any> = emptyMap()): Flow<Message> =
sendMessageAsync(Contents.of(text), extraContext)

private fun handleToolCalls(toolCallsJsonObject: JsonObject): JsonObject {
val toolCallsJSONArray = toolCallsJsonObject.getAsJsonArray("tool_calls")
@@ -328,6 +365,7 @@ class Conversation(
LiteRtLmJni.nativeSendMessageAsync(
handle,
localToolResponse.toString(),
"{}",
this@JniMessageCallbackImpl,
)
pendingToolResponseJSONMessage = null // Clear after sending
7 changes: 6 additions & 1 deletion kotlin/java/com/google/ai/edge/litertlm/LiteRtLmJni.kt
@@ -215,6 +215,7 @@ internal object LiteRtLmJni {
external fun nativeSendMessageAsync(
conversationPointer: Long,
messageJsonString: String,
extraContextJsonString: String,
callback: JniMessageCallback,
)

@@ -225,7 +226,11 @@
* @param messageJsonString The message to be processed by the native conversation instance.
* @return The response message in JSON string format.
*/
external fun nativeSendMessage(conversationPointer: Long, messageJsonString: String): String
external fun nativeSendMessage(
conversationPointer: Long,
messageJsonString: String,
extraContextJsonString: String,
): String

/**
* Cancels the ongoing conversation process.
1 change: 1 addition & 0 deletions kotlin/java/com/google/ai/edge/litertlm/jni/BUILD
@@ -25,6 +25,7 @@ cc_binary(
linkshared = 1,
deps = [
"@com_google_absl//absl/base:log_severity",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/log:globals",
39 changes: 34 additions & 5 deletions kotlin/java/com/google/ai/edge/litertlm/jni/litertlm.cc
@@ -295,6 +295,19 @@ SamplerParameters CreateSamplerParamsFromJni(JNIEnv* env,

return sampler_params;
}

nlohmann::ordered_json GetExtraContextJson(JNIEnv* env,
jstring extra_context_json_string) {
const char* extra_context_chars =
env->GetStringUTFChars(extra_context_json_string, nullptr);
nlohmann::ordered_json extra_context_json;
if (extra_context_chars != nullptr) {
extra_context_json = nlohmann::ordered_json::parse(extra_context_chars);
}
env->ReleaseStringUTFChars(extra_context_json_string, extra_context_chars);
return extra_context_json;
}

} // namespace

extern "C" {
@@ -851,7 +864,8 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeDeleteConversation)(

LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
JNIEnv* env, jclass thiz, jlong conversation_pointer,
jstring messageJSONString, jobject callback) {
jstring messageJSONString, jstring extraContextJsonString,
jobject callback) {
JavaVM* jvm = nullptr;
if (env->GetJavaVM(&jvm) != JNI_OK) {
ThrowLiteRtLmJniException(env, "Failed to get JavaVM");
@@ -866,6 +880,13 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(
nlohmann::ordered_json::parse(json_chars);
env->ReleaseStringUTFChars(messageJSONString, json_chars);

litert::lm::OptionalArgs optional_args;
nlohmann::ordered_json extra_context =
GetExtraContextJson(env, extraContextJsonString);
if (!extra_context.is_null() && !extra_context.empty()) {
optional_args.extra_context = extra_context;
}

jobject callback_global = env->NewGlobalRef(callback);
jclass callback_class = env->GetObjectClass(callback_global);
jmethodID on_message_mid =
@@ -932,8 +953,8 @@
}
};

auto status =
conversation->SendMessageAsync(json_message, std::move(callback_fn));
auto status = conversation->SendMessageAsync(
json_message, std::move(callback_fn), std::move(optional_args));

if (!status.ok()) {
ThrowLiteRtLmJniException(
@@ -943,7 +964,7 @@ LITERTLM_JNIEXPORT void JNICALL JNI_METHOD(nativeSendMessageAsync)(

LITERTLM_JNIEXPORT jstring JNICALL JNI_METHOD(nativeSendMessage)(
JNIEnv* env, jclass thiz, jlong conversation_pointer,
jstring messageJSONString) {
jstring messageJSONString, jstring extraContextJsonString) {
Conversation* conversation =
reinterpret_cast<Conversation*>(conversation_pointer);

@@ -952,7 +973,15 @@
nlohmann::ordered_json::parse(json_chars);
env->ReleaseStringUTFChars(messageJSONString, json_chars);

auto response = conversation->SendMessage(json_message);
litert::lm::OptionalArgs optional_args;
nlohmann::ordered_json extra_context =
GetExtraContextJson(env, extraContextJsonString);
if (!extra_context.is_null() && !extra_context.empty()) {
optional_args.extra_context = extra_context;
}

auto response =
conversation->SendMessage(json_message, std::move(optional_args));
if (!response.ok()) {
ThrowLiteRtLmJniException(env, "Failed to call nativeSendMessage: " +
response.status().ToString());
2 changes: 1 addition & 1 deletion runtime/conversation/BUILD
@@ -106,7 +106,6 @@ cc_library(
"@com_google_absl//absl/time",
"@nlohmann_json//:json",
"//runtime/components:prompt_template",
"//runtime/components:tokenizer",
"//runtime/components/constrained_decoding:constraint",
"//runtime/components/constrained_decoding:constraint_provider",
"//runtime/components/constrained_decoding:constraint_provider_config",
@@ -137,6 +136,7 @@ cc_test(
":conversation",
":io_types",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",