
Commit 8049af6 (LiteRT-LM)

advaitjain authored and copybara-github committed

Add explicit cancellation logic to OpenAI streaming responses in litert-lm serve.

PiperOrigin-RevId: 908385328

1 parent: 8cf2128

2 files changed: 59 additions & 4 deletions

File tree

python/litert_lm_cli/serve.py
python/litert_lm_cli/serve_openai_streaming_test.py

python/litert_lm_cli/serve.py (1 addition, 1 deletion)
@@ -386,7 +386,6 @@ def do_POST(self) -> None:  # pylint: disable=invalid-name
       )
       return
 
-    # TODO: b/507147993 - Handle client early disconnects robustly.
     # Handle streaming response using Server-Sent Events (SSE).
     # We send response.created, response.output_text.delta, and
     # response.completed events.
@@ -433,6 +432,7 @@ def do_POST(self) -> None:  # pylint: disable=invalid-name
       self.wfile.flush()
     except Exception as e:
       click.echo(click.style(f"Error during streaming: {e!r}", fg="red"))
+      conv.cancel_process()
       try:
         self.wfile.write(
             "event: response.error\ndata:"

python/litert_lm_cli/serve_openai_streaming_test.py (58 additions, 3 deletions)
@@ -1,7 +1,9 @@
 import collections.abc
+import http.client
 import http.server
 import json
 import pathlib
+import socket
 import threading
 from unittest import mock
 import urllib.request
@@ -55,9 +57,6 @@ def tearDown(self):
     super().tearDown()
 
   def test_openai_responses_streaming(self):
-    self.assertTrue(
-        self.model_path.exists(), f"Model not found at {self.model_path}"
-    )
 
     mock_from_id = self.enter_context(
         mock.patch.object(model.Model, "from_model_id", autospec=True)
@@ -114,6 +113,62 @@ def test_openai_responses_streaming(self):
     with self.subTest(name="Verify DONE message"):
       self.assertIn("data: [DONE]", lines)
 
+  def test_openai_responses_streaming_client_disconnect(self):
+
+    mock_from_id = self.enter_context(
+        mock.patch.object(model.Model, "from_model_id", autospec=True)
+    )
+    mock_from_id.return_value = model.Model(
+        model_id="gemma3", model_path=str(self.model_path)
+    )
+
+    data = json.dumps(
+        {"model": "gemma3", "input": "Count to 50", "stream": True}
+    ).encode("utf-8")
+
+    req = urllib.request.Request(
+        f"http://localhost:{self.port}/v1/responses",
+        data=data,
+        headers={"Content-Type": "application/json"},
+    )
+
+    response = urllib.request.urlopen(req, timeout=60)
+    self.assertEqual(response.getcode(), 200)
+
+    for line in response:
+      line_str = line.decode("utf-8")
+      if line_str.startswith("event: response.output_text.delta"):
+        data_line = next(response).decode("utf-8")
+        self.assertStartsWith(data_line, "data: ")
+        break
+    else:
+      self.fail("Stream ended early without delta event")
+
+    # This tests a scenario where a client makes a request and exits before the
+    # response is completed. Note: this assumes prefill is already complete.
+    # TODO: b/508348544 - There are other scenarios where a client can cause the
+    # server to hang.
+    response.close()
+
+    conn = http.client.HTTPConnection("localhost", self.port, timeout=15)
+    try:
+      conn.request(
+          "POST",
+          "/v1/responses",
+          body=json.dumps({"model": "gemma3", "input": "Hi"}).encode("utf-8"),
+          headers={"Content-Type": "application/json"},
+      )
+      try:
+        response2 = conn.getresponse()
+      except Exception as e:
+        self.fail(f"Second request failed (timed out as expected?): {e!r}")
+
+      self.assertEqual(response2.status, 200)
+      res_body2 = json.loads(response2.read().decode("utf-8"))
+      self.assertIn("id", res_body2)
+    finally:
+      conn.close()
+
 
 if __name__ == "__main__":
   absltest.main()
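
The second, non-streaming request is the actual assertion of the fix: the aborted stream leaves the server mid-generation on "Count to 50", and if conv.cancel_process() did not run, the follow-up request would presumably hang until the 15-second connection timeout tripped (hence the hedged failure message "timed out as expected?"). With cancellation in place, the engine is freed and the follow-up returns a normal JSON body containing an "id" field.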
