Skip to content

Commit 47f00ee

Browse files
authored
Merge pull request #15 from dev-msp/stream-response
Stream generation responses
2 parents 4158fef + 5c19c5a commit 47f00ee

6 files changed

Lines changed: 314 additions & 33 deletions

File tree

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ readme = "README.md"
1414
keywords = ["ai", "google", "gemini"]
1515

1616
[dependencies]
17+
bytes = "1.10.1"
1718
futures = "0.3"
18-
reqwest = { version = "0.12.12", features = ["json", "rustls-tls"] }
19+
reqwest = { version = "0.12.19", features = ["stream", "json", "rustls-tls"] }
1920
secrecy = "0.10"
2021
serde = { version = "1.0", features = ["derive"] }
2122
serde_json = "1.0"

examples/streaming.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use futures::StreamExt as _;
2+
use gemini_rs::types::{CodeExecutionTool, Tools};
3+
4+
#[tokio::main]
5+
async fn main() -> Result<(), Box<dyn std::error::Error>> {
6+
let client = gemini_rs::Client::instance();
7+
let tools = vec![Tools {
8+
function_declarations: None,
9+
google_search: None,
10+
code_execution: Some(CodeExecutionTool {}),
11+
}];
12+
let mut req = client.stream_generate_content("gemini-2.5-flash-preview-05-20");
13+
req.message("what's the sum of the prime numbers between 1 and 100?");
14+
req.tools(tools);
15+
16+
let stream = req.stream().await?;
17+
println!("Stream started...");
18+
19+
stream
20+
.for_each(|chunk| async move {
21+
match chunk {
22+
Ok(chunk) => println!("Chunk: {:?}", chunk.candidates),
23+
Err(e) => eprintln!("Error in stream: {:?}", e),
24+
}
25+
})
26+
.await;
27+
28+
Ok(())
29+
}

src/client.rs

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@ use std::{
44
sync::{Arc, LazyLock},
55
};
66

7-
use crate::{Chat, Error, Result, chat, types};
87
use futures::FutureExt as _;
98
use reqwest::Method;
109
use secrecy::{ExposeSecret as _, SecretString};
11-
use serde::Serialize;
1210

13-
const BASE_URI: &str = "https://generativelanguage.googleapis.com";
11+
use crate::{Chat, Error, Result, StreamGenerateContent, chat, types};
12+
13+
pub(crate) const BASE_URI: &str = "https://generativelanguage.googleapis.com";
1414

1515
pub struct Route<T> {
16-
client: Client,
17-
kind: T,
16+
pub(crate) client: Client,
17+
pub(crate) kind: T,
1818
}
1919

2020
impl<T> Route<T> {
@@ -38,16 +38,11 @@ impl<T: Request> IntoFuture for Route<T> {
3838
.request(T::METHOD, format!("{BASE_URI}/{self}"));
3939

4040
if let Some(body) = self.kind.body() {
41-
// Debug print the request body
42-
if let Ok(body_json) = serde_json::to_string_pretty(&body) {
43-
println!("Request body: {body_json}");
44-
}
4541
request = request.json(&body);
4642
};
4743

4844
let response = request.send().await?;
4945
let raw_json = response.text().await?;
50-
println!("Response: {raw_json}");
5146

5247
match serde_json::from_str::<types::ApiResponse<T::Model>>(&raw_json)? {
5348
types::ApiResponse::Ok(response) => Ok(response),
@@ -58,15 +53,15 @@ impl<T: Request> IntoFuture for Route<T> {
5853
}
5954
}
6055

61-
impl<T> Deref for Route<T> {
62-
type Target = T;
56+
impl Deref for Route<GenerateContent> {
57+
type Target = GenerateContent;
6358

6459
fn deref(&self) -> &Self::Target {
6560
&self.kind
6661
}
6762
}
6863

69-
impl<T> DerefMut for Route<T> {
64+
impl DerefMut for Route<GenerateContent> {
7065
fn deref_mut(&mut self) -> &mut Self::Target {
7166
&mut self.kind
7267
}
@@ -83,7 +78,7 @@ impl<T: Request> std::fmt::Display for Route<T> {
8378
/// Covers the 20% of use cases that [Chat] doesn't
8479
#[derive(Clone)]
8580
pub struct Client {
86-
inner: Arc<ClientInner>,
81+
pub(crate) inner: Arc<ClientInner>,
8782
}
8883

8984
impl Deref for Client {
@@ -121,14 +116,18 @@ impl Client {
121116
Route::new(self, GenerateContent::new(model.into()))
122117
}
123118

119+
pub fn stream_generate_content(&self, model: &str) -> Route<StreamGenerateContent> {
120+
Route::new(self, StreamGenerateContent::new(model))
121+
}
122+
124123
pub fn instance() -> Client {
125124
static STATIC_INSTANCE: LazyLock<Client> = LazyLock::new(Client::default);
126125
STATIC_INSTANCE.clone()
127126
}
128127
}
129128

130129
pub struct GenerateContent {
131-
model: Box<str>,
130+
pub(crate) model: Box<str>,
132131
pub body: types::GenerateContent,
133132
}
134133

@@ -186,8 +185,8 @@ impl Request for GenerateContent {
186185
fmt.write_str(":generateContent")
187186
}
188187

189-
fn body(self) -> Option<Self::Body> {
190-
Some(self.body)
188+
fn body(&self) -> Option<Self::Body> {
189+
Some(self.body.clone())
191190
}
192191
}
193192

@@ -241,14 +240,18 @@ impl DerefMut for Formatter<'_, '_> {
241240
}
242241

243242
impl<'me, 'buffer> Formatter<'me, 'buffer> {
244-
fn new(formatter: &'me mut std::fmt::Formatter<'buffer>) -> Self {
243+
pub(crate) fn new(formatter: &'me mut std::fmt::Formatter<'buffer>) -> Self {
245244
Self {
246245
formatter,
247246
is_first: true,
248247
}
249248
}
250249

251-
fn write_query_param(&mut self, key: &str, value: &impl std::fmt::Display) -> std::fmt::Result {
250+
pub(crate) fn write_query_param(
251+
&mut self,
252+
key: &str,
253+
value: &impl std::fmt::Display,
254+
) -> std::fmt::Result {
252255
if self.is_first {
253256
self.formatter.write_char('?')?;
254257
self.is_first = false;
@@ -275,7 +278,7 @@ impl<'me, 'buffer> Formatter<'me, 'buffer> {
275278
}
276279

277280
pub struct ClientInner {
278-
reqwest: reqwest::Client,
281+
pub(crate) reqwest: reqwest::Client,
279282
key: SecretString,
280283
}
281284

@@ -299,7 +302,7 @@ pub trait Request: Send + Sized + 'static {
299302

300303
fn format_uri(&self, fmt: &mut Formatter<'_, '_>) -> std::fmt::Result;
301304

302-
fn body(self) -> Option<Self::Body> {
305+
fn body(&self) -> Option<Self::Body> {
303306
None
304307
}
305308
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,15 @@
4848
mod chat;
4949
mod client;
5050
mod error;
51+
mod stream;
5152
pub mod types;
5253

5354
pub type Result<T> = std::result::Result<T, Error>;
5455

5556
pub use chat::Chat;
5657
pub use client::Client;
5758
pub use error::Error;
59+
pub use stream::{RouteStream, StreamGenerateContent};
5860

5961
/// Creates a new Gemini client instance using the default configuration.
6062
pub fn client() -> Client {

0 commit comments

Comments
 (0)