|
1 | 1 | import collections.abc |
| 2 | +import http.client |
2 | 3 | import http.server |
3 | 4 | import json |
4 | 5 | import pathlib |
| 6 | +import socket |
5 | 7 | import threading |
6 | 8 | from unittest import mock |
7 | 9 | import urllib.request |
@@ -55,9 +57,6 @@ def tearDown(self): |
55 | 57 | super().tearDown() |
56 | 58 |
|
57 | 59 | def test_openai_responses_streaming(self): |
58 | | - self.assertTrue( |
59 | | - self.model_path.exists(), f"Model not found at {self.model_path}" |
60 | | - ) |
61 | 60 |
|
62 | 61 | mock_from_id = self.enter_context( |
63 | 62 | mock.patch.object(model.Model, "from_model_id", autospec=True) |
@@ -114,6 +113,62 @@ def test_openai_responses_streaming(self): |
114 | 113 | with self.subTest(name="Verify DONE message"): |
115 | 114 | self.assertIn("data: [DONE]", lines) |
116 | 115 |
|
| 116 | + def test_openai_responses_streaming_client_disconnect(self): |
| 117 | + |
| 118 | + mock_from_id = self.enter_context( |
| 119 | + mock.patch.object(model.Model, "from_model_id", autospec=True) |
| 120 | + ) |
| 121 | + mock_from_id.return_value = model.Model( |
| 122 | + model_id="gemma3", model_path=str(self.model_path) |
| 123 | + ) |
| 124 | + |
| 125 | + data = json.dumps( |
| 126 | + {"model": "gemma3", "input": "Count to 50", "stream": True} |
| 127 | + ).encode("utf-8") |
| 128 | + |
| 129 | + req = urllib.request.Request( |
| 130 | + f"http://localhost:{self.port}/v1/responses", |
| 131 | + data=data, |
| 132 | + headers={"Content-Type": "application/json"}, |
| 133 | + ) |
| 134 | + |
| 135 | + response = urllib.request.urlopen(req, timeout=60) |
| 136 | + self.assertEqual(response.getcode(), 200) |
| 137 | + |
| 138 | + for line in response: |
| 139 | + line_str = line.decode("utf-8") |
| 140 | + if line_str.startswith("event: response.output_text.delta"): |
| 141 | + data_line = next(response).decode("utf-8") |
| 142 | + self.assertStartsWith(data_line, "data: ") |
| 143 | + break |
| 144 | + else: |
| 145 | + self.fail("Stream ended early without delta event") |
| 146 | + |
| 147 | + # This tests a scenario where a client makes a request and exits before the |
| 148 | + # response is completed. Note: this assumes prefill is already complete. |
| 149 | + # TODO: b/508348544 - There are other scenarios where a client can cause the |
| 150 | + # server to hang. |
| 151 | + response.close() |
| 152 | + |
| 153 | + conn = http.client.HTTPConnection("localhost", self.port, timeout=15) |
| 154 | + try: |
| 155 | + conn.request( |
| 156 | + "POST", |
| 157 | + "/v1/responses", |
| 158 | + body=json.dumps({"model": "gemma3", "input": "Hi"}).encode("utf-8"), |
| 159 | + headers={"Content-Type": "application/json"}, |
| 160 | + ) |
| 161 | + try: |
| 162 | + response2 = conn.getresponse() |
| 163 | + except Exception as e: |
| 164 | + self.fail(f"Second request failed (timed out as expected?): {e!r}") |
| 165 | + |
| 166 | + self.assertEqual(response2.status, 200) |
| 167 | + res_body2 = json.loads(response2.read().decode("utf-8")) |
| 168 | + self.assertIn("id", res_body2) |
| 169 | + finally: |
| 170 | + conn.close() |
| 171 | + |
117 | 172 |
|
118 | 173 | if __name__ == "__main__": |
119 | 174 | absltest.main() |
0 commit comments