Skip to content

Commit de7ee29

Browse files
committed
feat(grpc): add server-side stream validators
Introduces `ServerSendStreamValidator` and `ServerRecvStreamValidator` to enforce strict gRPC semantics on server streams. These are intended to wrap raw transport streams before passing them to application-level handlers, ensuring that user-provided service logic cannot violate protocol rules. - `ServerSendStreamValidator` ensures proper response sequencing (Headers -> Messages -> Trailers) and prevents invalid state transitions, while fully supporting trailers-only responses. - `ServerRecvStreamValidator` safely manages terminal states, ensuring that any subsequent polls after stream completion consistently return an error to prevent undefined behavior. - Adds comprehensive unit tests to verify state machine correctness and protocol compliance.
1 parent 368361b commit de7ee29

2 files changed

Lines changed: 385 additions & 0 deletions

File tree

grpc/src/server/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ use crate::core::ServerResponseStreamItem;
3232
use crate::core::Trailers;
3333
use tokio::sync::oneshot;
3434

35+
pub mod stream_util;
36+
3537
pub struct Server {
3638
handler: Option<Arc<dyn DynHandle>>,
3739
}

grpc/src/server/stream_util.rs

Lines changed: 383 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,383 @@
1+
/*
2+
*
3+
* Copyright 2026 gRPC authors.
4+
*
5+
* Permission is hereby granted, free of charge, to any person obtaining a copy
6+
* of this software and associated documentation files (the "Software"), to
7+
* deal in the Software without restriction, including without limitation the
8+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
9+
* sell copies of the Software, and to permit persons to whom the Software is
10+
* furnished to do so, subject to the following conditions:
11+
*
12+
* The above copyright notice and this permission notice shall be included in
13+
* all copies or substantial portions of the Software.
14+
*
15+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20+
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21+
* IN THE SOFTWARE.
22+
*
23+
*/
24+
25+
use crate::core::RecvMessage;
26+
use crate::core::ServerResponseStreamItem;
27+
use crate::server::RecvStream;
28+
use crate::server::SendOptions;
29+
use crate::server::SendStream;
30+
31+
/// Enforces proper gRPC semantics on the server sending stream.
32+
pub struct ServerSendStreamValidator<S> {
33+
inner: S,
34+
state: SendStreamState,
35+
}
36+
37+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
38+
enum SendStreamState {
39+
Init,
40+
HeadersSent,
41+
MessagesSent,
42+
Done,
43+
}
44+
45+
impl<S> ServerSendStreamValidator<S>
46+
where
47+
S: SendStream,
48+
{
49+
/// Constructs a new `ServerSendStreamValidator`.
50+
pub fn new(inner: S) -> Self {
51+
Self {
52+
inner,
53+
state: SendStreamState::Init,
54+
}
55+
}
56+
}
57+
58+
impl<S> SendStream for ServerSendStreamValidator<S>
59+
where
60+
S: SendStream,
61+
{
62+
async fn send<'a>(
63+
&mut self,
64+
item: ServerResponseStreamItem<'a>,
65+
options: SendOptions,
66+
) -> Result<(), ()> {
67+
if self.state == SendStreamState::Done {
68+
// Called send after stream completed
69+
return Err(());
70+
}
71+
72+
let next_state = match &item {
73+
ServerResponseStreamItem::Headers(_) => match self.state {
74+
SendStreamState::Init => SendStreamState::HeadersSent,
75+
_ => {
76+
// Received multiple headers frames
77+
self.state = SendStreamState::Done;
78+
return Err(());
79+
}
80+
},
81+
ServerResponseStreamItem::Message(_) => match self.state {
82+
SendStreamState::HeadersSent | SendStreamState::MessagesSent => {
83+
SendStreamState::MessagesSent
84+
}
85+
_ => {
86+
// Sent message before headers or stream completed
87+
self.state = SendStreamState::Done;
88+
return Err(());
89+
}
90+
},
91+
};
92+
93+
let res = self.inner.send(item, options).await;
94+
match res {
95+
Ok(()) => self.state = next_state,
96+
Err(_) => {
97+
// Underlying stream failed to send
98+
self.state = SendStreamState::Done;
99+
}
100+
}
101+
res
102+
}
103+
}
104+
105+
/// Enforces proper gRPC semantics on the server receiving stream.
106+
pub struct ServerRecvStreamValidator<R> {
107+
inner: R,
108+
state: RecvStreamState,
109+
}
110+
111+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
112+
enum RecvStreamState {
113+
Init,
114+
Done,
115+
}
116+
117+
impl<R> ServerRecvStreamValidator<R>
118+
where
119+
R: RecvStream,
120+
{
121+
/// Constructs a new `ServerRecvStreamValidator`.
122+
pub fn new(inner: R) -> Self {
123+
Self {
124+
inner,
125+
state: RecvStreamState::Init,
126+
}
127+
}
128+
}
129+
130+
impl<R> RecvStream for ServerRecvStreamValidator<R>
131+
where
132+
R: RecvStream,
133+
{
134+
async fn next(&mut self, msg: &mut dyn RecvMessage) -> Option<Result<(), ()>> {
135+
if self.state == RecvStreamState::Done {
136+
// Called next after stream completed
137+
return Some(Err(()));
138+
}
139+
140+
let res = self.inner.next(msg).await;
141+
142+
match res {
143+
Some(Ok(())) => Some(Ok(())),
144+
None => {
145+
self.state = RecvStreamState::Done;
146+
None
147+
}
148+
Some(Err(())) => {
149+
// Received error from inner stream
150+
self.state = RecvStreamState::Done;
151+
Some(Err(()))
152+
}
153+
}
154+
}
155+
}
156+
157+
#[cfg(test)]
158+
mod tests {
159+
use super::*;
160+
use crate::core::ResponseHeaders;
161+
use crate::core::SendMessage;
162+
use bytes::Buf;
163+
use bytes::Bytes;
164+
165+
impl SendMessage for () {
166+
fn encode(&self) -> Result<Box<dyn Buf + Send + Sync>, String> {
167+
Ok(Box::new(Bytes::new()))
168+
}
169+
}
170+
171+
#[derive(Debug, PartialEq, Eq)]
172+
enum SendEvent {
173+
Headers,
174+
Message,
175+
}
176+
177+
struct MockSendStream {
178+
events: Vec<SendEvent>,
179+
}
180+
181+
impl MockSendStream {
182+
fn new() -> Self {
183+
Self { events: Vec::new() }
184+
}
185+
}
186+
187+
impl SendStream for MockSendStream {
188+
async fn send<'a>(
189+
&mut self,
190+
item: ServerResponseStreamItem<'a>,
191+
_options: SendOptions,
192+
) -> Result<(), ()> {
193+
match item {
194+
ServerResponseStreamItem::Headers(_) => self.events.push(SendEvent::Headers),
195+
ServerResponseStreamItem::Message(_) => self.events.push(SendEvent::Message),
196+
}
197+
Ok(())
198+
}
199+
}
200+
201+
#[tokio::test]
202+
async fn test_send_validator_valid_full_stream() {
203+
let mock = MockSendStream::new();
204+
let mut validator = ServerSendStreamValidator::new(mock);
205+
206+
assert!(
207+
validator
208+
.send(
209+
ServerResponseStreamItem::Headers(ResponseHeaders::default()),
210+
SendOptions::default()
211+
)
212+
.await
213+
.is_ok()
214+
);
215+
assert!(
216+
validator
217+
.send(
218+
ServerResponseStreamItem::Message(&()),
219+
SendOptions::default()
220+
)
221+
.await
222+
.is_ok()
223+
);
224+
225+
assert_eq!(
226+
validator.inner.events,
227+
vec![SendEvent::Headers, SendEvent::Message]
228+
);
229+
}
230+
231+
struct FailingMockSendStream;
232+
233+
impl SendStream for FailingMockSendStream {
234+
async fn send<'a>(
235+
&mut self,
236+
_item: ServerResponseStreamItem<'a>,
237+
_options: SendOptions,
238+
) -> Result<(), ()> {
239+
Err(())
240+
}
241+
}
242+
243+
#[tokio::test]
244+
async fn test_send_validator_invalid_message_before_headers() {
245+
let mock = MockSendStream::new();
246+
let mut validator = ServerSendStreamValidator::new(mock);
247+
248+
assert!(
249+
validator
250+
.send(
251+
ServerResponseStreamItem::Message(&()),
252+
SendOptions::default()
253+
)
254+
.await
255+
.is_err()
256+
);
257+
assert_eq!(validator.state, SendStreamState::Done);
258+
}
259+
260+
#[tokio::test]
261+
async fn test_send_validator_invalid_headers_twice() {
262+
let mock = MockSendStream::new();
263+
let mut validator = ServerSendStreamValidator::new(mock);
264+
265+
assert!(
266+
validator
267+
.send(
268+
ServerResponseStreamItem::Headers(ResponseHeaders::default()),
269+
SendOptions::default()
270+
)
271+
.await
272+
.is_ok()
273+
);
274+
assert!(
275+
validator
276+
.send(
277+
ServerResponseStreamItem::Headers(ResponseHeaders::default()),
278+
SendOptions::default()
279+
)
280+
.await
281+
.is_err()
282+
);
283+
}
284+
285+
#[tokio::test]
286+
async fn test_send_validator_state_transitions_to_done_on_error() {
287+
let mock = FailingMockSendStream;
288+
let mut validator = ServerSendStreamValidator::new(mock);
289+
290+
assert!(
291+
validator
292+
.send(
293+
ServerResponseStreamItem::Headers(ResponseHeaders::default()),
294+
SendOptions::default()
295+
)
296+
.await
297+
.is_err()
298+
);
299+
assert_eq!(validator.state, SendStreamState::Done);
300+
}
301+
302+
struct MockRecvStream {
303+
items: Vec<Option<Result<(), ()>>>,
304+
index: usize,
305+
}
306+
307+
impl MockRecvStream {
308+
fn new(items: Vec<Option<Result<(), ()>>>) -> Self {
309+
Self { items, index: 0 }
310+
}
311+
}
312+
313+
impl RecvStream for MockRecvStream {
314+
async fn next(&mut self, _msg: &mut dyn RecvMessage) -> Option<Result<(), ()>> {
315+
if self.index < self.items.len() {
316+
let res = self.items[self.index];
317+
self.index += 1;
318+
res
319+
} else {
320+
None
321+
}
322+
}
323+
}
324+
325+
struct NopRecvMessage;
326+
impl RecvMessage for NopRecvMessage {
327+
fn decode(&mut self, _data: &mut dyn bytes::Buf) -> Result<(), String> {
328+
Ok(())
329+
}
330+
}
331+
332+
#[tokio::test]
333+
async fn test_recv_validator_valid_unary() {
334+
let mock = MockRecvStream::new(vec![Some(Ok(())), None]);
335+
let mut validator = ServerRecvStreamValidator::new(mock);
336+
let mut msg = NopRecvMessage;
337+
338+
assert!(matches!(validator.next(&mut msg).await, Some(Ok(()))));
339+
assert!(validator.next(&mut msg).await.is_none());
340+
}
341+
342+
#[tokio::test]
343+
async fn test_recv_validator_empty_stream() {
344+
let mock = MockRecvStream::new(vec![None]);
345+
let mut validator = ServerRecvStreamValidator::new(mock);
346+
let mut msg = NopRecvMessage;
347+
348+
assert!(validator.next(&mut msg).await.is_none());
349+
}
350+
351+
#[tokio::test]
352+
async fn test_recv_validator_error_after_done() {
353+
let mock = MockRecvStream::new(vec![None]);
354+
let mut validator = ServerRecvStreamValidator::new(mock);
355+
let mut msg = NopRecvMessage;
356+
357+
assert!(validator.next(&mut msg).await.is_none());
358+
assert!(matches!(validator.next(&mut msg).await, Some(Err(()))));
359+
}
360+
361+
#[tokio::test]
362+
async fn test_recv_validator_valid_streaming() {
363+
let mock = MockRecvStream::new(vec![Some(Ok(())), Some(Ok(())), None]);
364+
let mut validator = ServerRecvStreamValidator::new(mock);
365+
let mut msg = NopRecvMessage;
366+
367+
assert!(matches!(validator.next(&mut msg).await, Some(Ok(()))));
368+
assert!(matches!(validator.next(&mut msg).await, Some(Ok(()))));
369+
assert!(validator.next(&mut msg).await.is_none());
370+
}
371+
372+
#[tokio::test]
373+
async fn test_recv_validator_terminal_error() {
374+
let mock = MockRecvStream::new(vec![Some(Err(()))]);
375+
let mut validator = ServerRecvStreamValidator::new(mock);
376+
let mut msg = NopRecvMessage;
377+
378+
assert!(matches!(validator.next(&mut msg).await, Some(Err(()))));
379+
380+
// Further calls should return the same error
381+
assert!(matches!(validator.next(&mut msg).await, Some(Err(()))));
382+
}
383+
}

0 commit comments

Comments
 (0)