Skip to content

Commit 5c19c5a

Browse files
committed
refactor(stream): move streaming logic to dedicated module
* Move `StreamGenerateContent` and `RouteStream` types * Move parsing logic for streamed responses * Update imports in `client.rs` and `lib.rs` * Adjust visibility of related fields/methods for inter-module access
1 parent e685400 commit 5c19c5a

3 files changed

Lines changed: 259 additions & 230 deletions

File tree

src/client.rs

Lines changed: 15 additions & 229 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,20 @@
11
use std::{
22
fmt::Write as _,
33
ops::{Deref, DerefMut},
4-
pin::Pin,
54
sync::{Arc, LazyLock},
6-
task::Poll,
75
};
86

9-
use bytes::Bytes;
10-
use futures::{FutureExt as _, Stream};
7+
use futures::FutureExt as _;
118
use reqwest::Method;
129
use secrecy::{ExposeSecret as _, SecretString};
13-
use serde::ser::Error as _;
1410

15-
use crate::{Chat, Error, Result, chat, types};
11+
use crate::{Chat, Error, Result, StreamGenerateContent, chat, types};
1612

17-
const BASE_URI: &str = "https://generativelanguage.googleapis.com";
13+
pub(crate) const BASE_URI: &str = "https://generativelanguage.googleapis.com";
1814

1915
pub struct Route<T> {
20-
client: Client,
21-
kind: T,
16+
pub(crate) client: Client,
17+
pub(crate) kind: T,
2218
}
2319

2420
impl<T> Route<T> {
@@ -71,46 +67,6 @@ impl DerefMut for Route<GenerateContent> {
7167
}
7268
}
7369

74-
impl Deref for Route<StreamGenerateContent> {
75-
type Target = GenerateContent;
76-
77-
fn deref(&self) -> &Self::Target {
78-
&self.kind.0
79-
}
80-
}
81-
82-
impl DerefMut for Route<StreamGenerateContent> {
83-
fn deref_mut(&mut self) -> &mut Self::Target {
84-
&mut self.kind.0
85-
}
86-
}
87-
88-
impl Route<StreamGenerateContent> {
89-
pub async fn stream(self) -> std::result::Result<RouteStream<StreamGenerateContent>, String> {
90-
let url = format!("{BASE_URI}/{}", self);
91-
let body = self.kind.body().clone();
92-
let mut request = self
93-
.client
94-
.reqwest
95-
.request(StreamGenerateContent::METHOD, url);
96-
97-
if let Some(body) = body {
98-
request = request.json(&body);
99-
}
100-
101-
let response = request.send().await.map_err(|e| e.to_string())?;
102-
let stream = response.bytes_stream();
103-
104-
Ok(RouteStream {
105-
phantom: std::marker::PhantomData,
106-
stream: Box::pin(stream),
107-
buffer: Vec::new(),
108-
pos: 0,
109-
state: ParseState::CannotAdvance,
110-
})
111-
}
112-
}
113-
11470
impl<T: Request> std::fmt::Display for Route<T> {
11571
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
11672
let mut fmt = Formatter::new(fmt);
@@ -119,153 +75,10 @@ impl<T: Request> std::fmt::Display for Route<T> {
11975
}
12076
}
12177

122-
pub struct RouteStream<T> {
123-
phantom: std::marker::PhantomData<T>,
124-
stream: Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>,
125-
buffer: Vec<u8>,
126-
pos: usize, // A cursor into the buffer.
127-
state: ParseState,
128-
}
129-
130-
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
131-
enum ParseState {
132-
CannotAdvance,
133-
ReadingChars,
134-
ReadingValue,
135-
Finished,
136-
}
137-
138-
#[derive(Debug)]
139-
enum ParseOutcome {
140-
Ok(Option<types::Response>),
141-
Err(serde_json::Error),
142-
Eof,
143-
}
144-
145-
impl RouteStream<StreamGenerateContent> {
146-
fn next_char_pos(&self) -> Option<usize> {
147-
self.buffer[self.pos..]
148-
.iter()
149-
.position(|&b| !b.is_ascii_whitespace())
150-
.map(|p| self.pos + p)
151-
}
152-
153-
fn advance_next_char(&mut self) -> Option<u8> {
154-
self.pos = self.next_char_pos().unwrap_or(self.buffer.len());
155-
self.buffer.get(self.pos).copied()
156-
}
157-
158-
fn current_char(&self) -> Option<u8> {
159-
self.buffer.get(self.pos).copied()
160-
}
161-
162-
fn is_bridge_char(&self) -> bool {
163-
matches!(self.current_char(), Some(b'[') | Some(b','))
164-
}
165-
166-
fn parse_chunk(&mut self) -> ParseOutcome {
167-
let mut de = serde_json::Deserializer::from_slice(&self.buffer[self.pos..])
168-
.into_iter::<types::Response>();
169-
match de.next() {
170-
Some(Ok(value)) => {
171-
self.pos += de.byte_offset();
172-
ParseOutcome::Ok(Some(value))
173-
}
174-
Some(Err(e)) if e.is_eof() => ParseOutcome::Eof,
175-
Some(Err(e)) => ParseOutcome::Err(e),
176-
None => ParseOutcome::Ok(None), // No more objects to read.
177-
}
178-
}
179-
180-
fn try_parse_next(&mut self) -> Option<ParseOutcome> {
181-
match self.state {
182-
ParseState::CannotAdvance => None, // nothing to read
183-
ParseState::ReadingChars => {
184-
self.advance_next_char();
185-
if self.is_bridge_char() {
186-
self.pos += 1; // Move past this '[' or ','
187-
self.state = ParseState::ReadingValue;
188-
None
189-
} else if let Some(b']') = self.current_char() {
190-
// If we hit a ']', we can finish reading.
191-
self.state = ParseState::Finished;
192-
Some(ParseOutcome::Ok(None))
193-
} else {
194-
None
195-
}
196-
}
197-
ParseState::ReadingValue => {
198-
self.advance_next_char();
199-
// Deserialize one object from our current position.
200-
let outcome = self.parse_chunk();
201-
match &outcome {
202-
ParseOutcome::Ok(Some(_)) => {
203-
self.state = ParseState::ReadingChars;
204-
}
205-
ParseOutcome::Ok(None) | ParseOutcome::Err(_) => {
206-
self.state = ParseState::Finished;
207-
}
208-
ParseOutcome::Eof => {}
209-
};
210-
Some(outcome)
211-
}
212-
ParseState::Finished => None,
213-
}
214-
}
215-
}
216-
217-
impl Stream for RouteStream<StreamGenerateContent> {
218-
type Item = Result<types::Response>;
219-
220-
fn poll_next(
221-
mut self: Pin<&mut Self>,
222-
cx: &mut std::task::Context<'_>,
223-
) -> Poll<Option<Self::Item>> {
224-
loop {
225-
// Housekeeping: drain the buffer if we've processed a lot.
226-
if self.pos > 2048 {
227-
let this_pos = self.pos;
228-
self.buffer.drain(..this_pos);
229-
self.pos = 0;
230-
}
231-
232-
if let Some(ParseOutcome::Ok(Some(response))) = self.try_parse_next() {
233-
return Poll::Ready(Some(Ok(response)));
234-
}
235-
236-
// If we fell through, we need more data. Poll the underlying stream.
237-
match self.stream.as_mut().poll_next(cx) {
238-
Poll::Ready(Some(Ok(bytes))) => {
239-
if self.buffer.is_empty() && !bytes.is_empty() {
240-
self.state = ParseState::ReadingChars;
241-
}
242-
self.buffer.extend_from_slice(&bytes);
243-
continue; // Loop again to process new data.
244-
}
245-
Poll::Pending => return Poll::Pending,
246-
Poll::Ready(Some(Err(e))) => {
247-
self.state = ParseState::Finished;
248-
return Poll::Ready(Some(Err(Error::Http(e))));
249-
}
250-
Poll::Ready(None) => {
251-
// Underlying stream ended. Check if we're in a clean state.
252-
if self.state != ParseState::Finished && self.pos < self.buffer.len() {
253-
let msg =
254-
format!("stream ended with unparsed data in state {:?}", self.state);
255-
return Poll::Ready(Some(Err(serde_json::Error::custom(msg).into())));
256-
}
257-
self.state = ParseState::Finished;
258-
return Poll::Ready(None);
259-
}
260-
}
261-
}
262-
}
263-
}
264-
26578
/// Covers the 20% of use cases that [Chat] doesn't
26679
#[derive(Clone)]
26780
pub struct Client {
268-
inner: Arc<ClientInner>,
81+
pub(crate) inner: Arc<ClientInner>,
26982
}
27083

27184
impl Deref for Client {
@@ -304,10 +117,7 @@ impl Client {
304117
}
305118

306119
pub fn stream_generate_content(&self, model: &str) -> Route<StreamGenerateContent> {
307-
Route::new(
308-
self,
309-
StreamGenerateContent(GenerateContent::new(model.into())),
310-
)
120+
Route::new(self, StreamGenerateContent::new(model))
311121
}
312122

313123
pub fn instance() -> Client {
@@ -317,7 +127,7 @@ impl Client {
317127
}
318128

319129
pub struct GenerateContent {
320-
model: Box<str>,
130+
pub(crate) model: Box<str>,
321131
pub body: types::GenerateContent,
322132
}
323133

@@ -380,34 +190,6 @@ impl Request for GenerateContent {
380190
}
381191
}
382192

383-
pub struct StreamGenerateContent(GenerateContent);
384-
385-
impl Request for StreamGenerateContent {
386-
type Model = types::Response;
387-
type Body = types::GenerateContent;
388-
389-
const METHOD: Method = Method::POST;
390-
391-
fn format_uri(&self, fmt: &mut Formatter<'_, '_>) -> std::fmt::Result {
392-
fmt.write_str("v1beta/")?;
393-
fmt.write_str("models/")?;
394-
fmt.write_str(&self.0.model)?;
395-
fmt.write_str(":streamGenerateContent")
396-
}
397-
398-
fn body(&self) -> Option<Self::Body> {
399-
Some(self.0.body.clone())
400-
}
401-
}
402-
403-
impl std::fmt::Display for StreamGenerateContent {
404-
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
405-
let mut fmt = Formatter::new(fmt);
406-
self.format_uri(&mut fmt)?;
407-
fmt.write_query_param("key", &self.0.model)
408-
}
409-
}
410-
411193
#[derive(Default)]
412194
pub struct Models {
413195
page_size: Option<usize>,
@@ -458,14 +240,18 @@ impl DerefMut for Formatter<'_, '_> {
458240
}
459241

460242
impl<'me, 'buffer> Formatter<'me, 'buffer> {
461-
fn new(formatter: &'me mut std::fmt::Formatter<'buffer>) -> Self {
243+
pub(crate) fn new(formatter: &'me mut std::fmt::Formatter<'buffer>) -> Self {
462244
Self {
463245
formatter,
464246
is_first: true,
465247
}
466248
}
467249

468-
fn write_query_param(&mut self, key: &str, value: &impl std::fmt::Display) -> std::fmt::Result {
250+
pub(crate) fn write_query_param(
251+
&mut self,
252+
key: &str,
253+
value: &impl std::fmt::Display,
254+
) -> std::fmt::Result {
469255
if self.is_first {
470256
self.formatter.write_char('?')?;
471257
self.is_first = false;
@@ -492,7 +278,7 @@ impl<'me, 'buffer> Formatter<'me, 'buffer> {
492278
}
493279

494280
pub struct ClientInner {
495-
reqwest: reqwest::Client,
281+
pub(crate) reqwest: reqwest::Client,
496282
key: SecretString,
497283
}
498284

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,15 @@
4848
mod chat;
4949
mod client;
5050
mod error;
51+
mod stream;
5152
pub mod types;
5253

5354
pub type Result<T> = std::result::Result<T, Error>;
5455

5556
pub use chat::Chat;
5657
pub use client::Client;
57-
pub use client::{RouteStream, StreamGenerateContent};
5858
pub use error::Error;
59+
pub use stream::{RouteStream, StreamGenerateContent};
5960

6061
/// Creates a new Gemini client instance using the default configuration.
6162
pub fn client() -> Client {

0 commit comments

Comments
 (0)