11use std:: {
22 fmt:: Write as _,
33 ops:: { Deref , DerefMut } ,
4- pin:: Pin ,
54 sync:: { Arc , LazyLock } ,
6- task:: Poll ,
75} ;
86
9- use bytes:: Bytes ;
10- use futures:: { FutureExt as _, Stream } ;
7+ use futures:: FutureExt as _;
118use reqwest:: Method ;
129use secrecy:: { ExposeSecret as _, SecretString } ;
13- use serde:: ser:: Error as _;
1410
15- use crate :: { Chat , Error , Result , chat, types} ;
11+ use crate :: { Chat , Error , Result , StreamGenerateContent , chat, types} ;
1612
17- const BASE_URI : & str = "https://generativelanguage.googleapis.com" ;
13+ pub ( crate ) const BASE_URI : & str = "https://generativelanguage.googleapis.com" ;
1814
1915pub struct Route < T > {
20- client : Client ,
21- kind : T ,
16+ pub ( crate ) client : Client ,
17+ pub ( crate ) kind : T ,
2218}
2319
2420impl < T > Route < T > {
@@ -71,46 +67,6 @@ impl DerefMut for Route<GenerateContent> {
7167 }
7268}
7369
74- impl Deref for Route < StreamGenerateContent > {
75- type Target = GenerateContent ;
76-
77- fn deref ( & self ) -> & Self :: Target {
78- & self . kind . 0
79- }
80- }
81-
82- impl DerefMut for Route < StreamGenerateContent > {
83- fn deref_mut ( & mut self ) -> & mut Self :: Target {
84- & mut self . kind . 0
85- }
86- }
87-
88- impl Route < StreamGenerateContent > {
89- pub async fn stream ( self ) -> std:: result:: Result < RouteStream < StreamGenerateContent > , String > {
90- let url = format ! ( "{BASE_URI}/{}" , self ) ;
91- let body = self . kind . body ( ) . clone ( ) ;
92- let mut request = self
93- . client
94- . reqwest
95- . request ( StreamGenerateContent :: METHOD , url) ;
96-
97- if let Some ( body) = body {
98- request = request. json ( & body) ;
99- }
100-
101- let response = request. send ( ) . await . map_err ( |e| e. to_string ( ) ) ?;
102- let stream = response. bytes_stream ( ) ;
103-
104- Ok ( RouteStream {
105- phantom : std:: marker:: PhantomData ,
106- stream : Box :: pin ( stream) ,
107- buffer : Vec :: new ( ) ,
108- pos : 0 ,
109- state : ParseState :: CannotAdvance ,
110- } )
111- }
112- }
113-
11470impl < T : Request > std:: fmt:: Display for Route < T > {
11571 fn fmt ( & self , fmt : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
11672 let mut fmt = Formatter :: new ( fmt) ;
@@ -119,153 +75,10 @@ impl<T: Request> std::fmt::Display for Route<T> {
11975 }
12076}
12177
122- pub struct RouteStream < T > {
123- phantom : std:: marker:: PhantomData < T > ,
124- stream : Pin < Box < dyn Stream < Item = std:: result:: Result < Bytes , reqwest:: Error > > + Send > > ,
125- buffer : Vec < u8 > ,
126- pos : usize , // A cursor into the buffer.
127- state : ParseState ,
128- }
129-
130- #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
131- enum ParseState {
132- CannotAdvance ,
133- ReadingChars ,
134- ReadingValue ,
135- Finished ,
136- }
137-
138- #[ derive( Debug ) ]
139- enum ParseOutcome {
140- Ok ( Option < types:: Response > ) ,
141- Err ( serde_json:: Error ) ,
142- Eof ,
143- }
144-
145- impl RouteStream < StreamGenerateContent > {
146- fn next_char_pos ( & self ) -> Option < usize > {
147- self . buffer [ self . pos ..]
148- . iter ( )
149- . position ( |& b| !b. is_ascii_whitespace ( ) )
150- . map ( |p| self . pos + p)
151- }
152-
153- fn advance_next_char ( & mut self ) -> Option < u8 > {
154- self . pos = self . next_char_pos ( ) . unwrap_or ( self . buffer . len ( ) ) ;
155- self . buffer . get ( self . pos ) . copied ( )
156- }
157-
158- fn current_char ( & self ) -> Option < u8 > {
159- self . buffer . get ( self . pos ) . copied ( )
160- }
161-
162- fn is_bridge_char ( & self ) -> bool {
163- matches ! ( self . current_char( ) , Some ( b'[' ) | Some ( b',' ) )
164- }
165-
166- fn parse_chunk ( & mut self ) -> ParseOutcome {
167- let mut de = serde_json:: Deserializer :: from_slice ( & self . buffer [ self . pos ..] )
168- . into_iter :: < types:: Response > ( ) ;
169- match de. next ( ) {
170- Some ( Ok ( value) ) => {
171- self . pos += de. byte_offset ( ) ;
172- ParseOutcome :: Ok ( Some ( value) )
173- }
174- Some ( Err ( e) ) if e. is_eof ( ) => ParseOutcome :: Eof ,
175- Some ( Err ( e) ) => ParseOutcome :: Err ( e) ,
176- None => ParseOutcome :: Ok ( None ) , // No more objects to read.
177- }
178- }
179-
180- fn try_parse_next ( & mut self ) -> Option < ParseOutcome > {
181- match self . state {
182- ParseState :: CannotAdvance => None , // nothing to read
183- ParseState :: ReadingChars => {
184- self . advance_next_char ( ) ;
185- if self . is_bridge_char ( ) {
186- self . pos += 1 ; // Move past this '[' or ','
187- self . state = ParseState :: ReadingValue ;
188- None
189- } else if let Some ( b']' ) = self . current_char ( ) {
190- // If we hit a ']', we can finish reading.
191- self . state = ParseState :: Finished ;
192- Some ( ParseOutcome :: Ok ( None ) )
193- } else {
194- None
195- }
196- }
197- ParseState :: ReadingValue => {
198- self . advance_next_char ( ) ;
199- // Deserialize one object from our current position.
200- let outcome = self . parse_chunk ( ) ;
201- match & outcome {
202- ParseOutcome :: Ok ( Some ( _) ) => {
203- self . state = ParseState :: ReadingChars ;
204- }
205- ParseOutcome :: Ok ( None ) | ParseOutcome :: Err ( _) => {
206- self . state = ParseState :: Finished ;
207- }
208- ParseOutcome :: Eof => { }
209- } ;
210- Some ( outcome)
211- }
212- ParseState :: Finished => None ,
213- }
214- }
215- }
216-
217- impl Stream for RouteStream < StreamGenerateContent > {
218- type Item = Result < types:: Response > ;
219-
220- fn poll_next (
221- mut self : Pin < & mut Self > ,
222- cx : & mut std:: task:: Context < ' _ > ,
223- ) -> Poll < Option < Self :: Item > > {
224- loop {
225- // Housekeeping: drain the buffer if we've processed a lot.
226- if self . pos > 2048 {
227- let this_pos = self . pos ;
228- self . buffer . drain ( ..this_pos) ;
229- self . pos = 0 ;
230- }
231-
232- if let Some ( ParseOutcome :: Ok ( Some ( response) ) ) = self . try_parse_next ( ) {
233- return Poll :: Ready ( Some ( Ok ( response) ) ) ;
234- }
235-
236- // If we fell through, we need more data. Poll the underlying stream.
237- match self . stream . as_mut ( ) . poll_next ( cx) {
238- Poll :: Ready ( Some ( Ok ( bytes) ) ) => {
239- if self . buffer . is_empty ( ) && !bytes. is_empty ( ) {
240- self . state = ParseState :: ReadingChars ;
241- }
242- self . buffer . extend_from_slice ( & bytes) ;
243- continue ; // Loop again to process new data.
244- }
245- Poll :: Pending => return Poll :: Pending ,
246- Poll :: Ready ( Some ( Err ( e) ) ) => {
247- self . state = ParseState :: Finished ;
248- return Poll :: Ready ( Some ( Err ( Error :: Http ( e) ) ) ) ;
249- }
250- Poll :: Ready ( None ) => {
251- // Underlying stream ended. Check if we're in a clean state.
252- if self . state != ParseState :: Finished && self . pos < self . buffer . len ( ) {
253- let msg =
254- format ! ( "stream ended with unparsed data in state {:?}" , self . state) ;
255- return Poll :: Ready ( Some ( Err ( serde_json:: Error :: custom ( msg) . into ( ) ) ) ) ;
256- }
257- self . state = ParseState :: Finished ;
258- return Poll :: Ready ( None ) ;
259- }
260- }
261- }
262- }
263- }
264-
26578/// Covers the 20% of use cases that [Chat] doesn't
26679#[ derive( Clone ) ]
26780pub struct Client {
268- inner : Arc < ClientInner > ,
81+ pub ( crate ) inner : Arc < ClientInner > ,
26982}
27083
27184impl Deref for Client {
@@ -304,10 +117,7 @@ impl Client {
304117 }
305118
306119 pub fn stream_generate_content ( & self , model : & str ) -> Route < StreamGenerateContent > {
307- Route :: new (
308- self ,
309- StreamGenerateContent ( GenerateContent :: new ( model. into ( ) ) ) ,
310- )
120+ Route :: new ( self , StreamGenerateContent :: new ( model) )
311121 }
312122
313123 pub fn instance ( ) -> Client {
@@ -317,7 +127,7 @@ impl Client {
317127}
318128
319129pub struct GenerateContent {
320- model : Box < str > ,
130+ pub ( crate ) model : Box < str > ,
321131 pub body : types:: GenerateContent ,
322132}
323133
@@ -380,34 +190,6 @@ impl Request for GenerateContent {
380190 }
381191}
382192
383- pub struct StreamGenerateContent ( GenerateContent ) ;
384-
385- impl Request for StreamGenerateContent {
386- type Model = types:: Response ;
387- type Body = types:: GenerateContent ;
388-
389- const METHOD : Method = Method :: POST ;
390-
391- fn format_uri ( & self , fmt : & mut Formatter < ' _ , ' _ > ) -> std:: fmt:: Result {
392- fmt. write_str ( "v1beta/" ) ?;
393- fmt. write_str ( "models/" ) ?;
394- fmt. write_str ( & self . 0 . model ) ?;
395- fmt. write_str ( ":streamGenerateContent" )
396- }
397-
398- fn body ( & self ) -> Option < Self :: Body > {
399- Some ( self . 0 . body . clone ( ) )
400- }
401- }
402-
403- impl std:: fmt:: Display for StreamGenerateContent {
404- fn fmt ( & self , fmt : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
405- let mut fmt = Formatter :: new ( fmt) ;
406- self . format_uri ( & mut fmt) ?;
407- fmt. write_query_param ( "key" , & self . 0 . model )
408- }
409- }
410-
411193#[ derive( Default ) ]
412194pub struct Models {
413195 page_size : Option < usize > ,
@@ -458,14 +240,18 @@ impl DerefMut for Formatter<'_, '_> {
458240}
459241
460242impl < ' me , ' buffer > Formatter < ' me , ' buffer > {
461- fn new ( formatter : & ' me mut std:: fmt:: Formatter < ' buffer > ) -> Self {
243+ pub ( crate ) fn new ( formatter : & ' me mut std:: fmt:: Formatter < ' buffer > ) -> Self {
462244 Self {
463245 formatter,
464246 is_first : true ,
465247 }
466248 }
467249
468- fn write_query_param ( & mut self , key : & str , value : & impl std:: fmt:: Display ) -> std:: fmt:: Result {
250+ pub ( crate ) fn write_query_param (
251+ & mut self ,
252+ key : & str ,
253+ value : & impl std:: fmt:: Display ,
254+ ) -> std:: fmt:: Result {
469255 if self . is_first {
470256 self . formatter . write_char ( '?' ) ?;
471257 self . is_first = false ;
@@ -492,7 +278,7 @@ impl<'me, 'buffer> Formatter<'me, 'buffer> {
492278}
493279
494280pub struct ClientInner {
495- reqwest : reqwest:: Client ,
281+ pub ( crate ) reqwest : reqwest:: Client ,
496282 key : SecretString ,
497283}
498284
0 commit comments