Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import ai.koog.prompt.message.LLMChoice
import ai.koog.prompt.message.Message
import ai.koog.prompt.streaming.IncompleteStreamException
import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.structure.json.generator.BasicJsonSchemaGenerator
import ai.koog.prompt.structure.json.generator.StandardJsonSchemaGenerator
import io.github.oshai.kotlinlogging.KotlinLogging
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.delay
Expand Down Expand Up @@ -186,6 +188,14 @@ public class RetryingLLMClient @JvmOverloads constructor(
override fun close() {
delegate.close()
}

override fun getStandardJsonSchemaGenerator(): StandardJsonSchemaGenerator {
return delegate.getStandardJsonSchemaGenerator()
}

override fun getBasicJsonSchemaGenerator(): BasicJsonSchemaGenerator {
return delegate.getBasicJsonSchemaGenerator()
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import ai.koog.prompt.streaming.StreamFrame
import ai.koog.prompt.streaming.emitTextDelta
import ai.koog.prompt.streaming.streamFrameFlow
import ai.koog.prompt.streaming.streamFrameFlowOf
import ai.koog.prompt.structure.json.generator.BasicJsonSchemaGenerator
import ai.koog.prompt.structure.json.generator.StandardJsonSchemaGenerator
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
Expand Down Expand Up @@ -440,6 +442,28 @@ class RetryingLLMClientTest {
assertEquals(1, mockClient.streamCalls) // No retry after first frame
}

@Test
fun testBasicJsonSchemaGeneratorDelegation() = runTest {
val mockClient = MockLLMClient()

val retryingClient = RetryingLLMClient(mockClient)

val result = retryingClient.getBasicJsonSchemaGenerator()

assertEquals(mockClient.basicJsonSchemaGeneratorDefault, result)
}

@Test
fun testStandardJsonSchemaGeneratorDelegation() = runTest {
val mockClient = MockLLMClient()

val retryingClient = RetryingLLMClient(mockClient)

val result = retryingClient.getStandardJsonSchemaGenerator()

assertEquals(mockClient.standardJsonSchemaGeneratorDefault, result)
}

// Mock LLMClient for testing
private class MockLLMClient(
private val executeResponse: List<Message.Response> = emptyList(),
Expand All @@ -453,6 +477,9 @@ class RetryingLLMClientTest {
private val llmProvider: LLMProvider = LLMProvider.OpenAI,
) : LLMClient() {

val basicJsonSchemaGeneratorDefault = object : BasicJsonSchemaGenerator() {}
val standardJsonSchemaGeneratorDefault = object : StandardJsonSchemaGenerator() {}

var executeCalls = 0
var streamCalls = 0
var multipleChoicesCalls = 0
Expand Down Expand Up @@ -528,5 +555,13 @@ class RetryingLLMClientTest {
override fun close() {
// No resources to close
}

override fun getBasicJsonSchemaGenerator(): BasicJsonSchemaGenerator {
return basicJsonSchemaGeneratorDefault
}

override fun getStandardJsonSchemaGenerator(): StandardJsonSchemaGenerator {
return standardJsonSchemaGeneratorDefault
}
}
}
Loading