Skip to content

Commit 89b8525

Browse files
committed
feat(grpc): add server-side stream validators
Introduces `ServerSendStreamValidator` and `ServerRecvStreamValidator` to enforce strict gRPC semantics on server streams. These are intended to wrap raw transport streams before passing them to application-level handlers, ensuring that user-provided service logic cannot violate protocol rules. - `ServerSendStreamValidator` ensures proper response sequencing (Headers -> Messages -> Trailers) and prevents invalid state transitions, while fully supporting trailers-only responses. - `ServerRecvStreamValidator` safely manages terminal states, ensuring that any subsequent polls after stream completion consistently return an error to prevent undefined behavior. - Adds comprehensive unit tests to verify state machine correctness and protocol compliance.
1 parent 90ae261 commit 89b8525

2 files changed

Lines changed: 377 additions & 0 deletions

File tree

grpc/src/server/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ use crate::core::ServerResponseStreamItem;
3232
use crate::core::Trailers;
3333
use tokio::sync::oneshot;
3434

35+
pub mod stream_util;
36+
3537
pub struct Server {
3638
handler: Option<Arc<dyn DynHandle>>,
3739
}

grpc/src/server/stream_util.rs

Lines changed: 375 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,375 @@
1+
/*
2+
*
3+
* Copyright 2026 gRPC authors.
4+
*
5+
* Permission is hereby granted, free of charge, to any person obtaining a copy
6+
* of this software and associated documentation files (the "Software"), to
7+
* deal in the Software without restriction, including without limitation the
8+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
9+
* sell copies of the Software, and to permit persons to whom the Software is
10+
* furnished to do so, subject to the following conditions:
11+
*
12+
* The above copyright notice and this permission notice shall be included in
13+
* all copies or substantial portions of the Software.
14+
*
15+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20+
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21+
* IN THE SOFTWARE.
22+
*
23+
*/
24+
25+
use crate::core::RecvMessage;
26+
use crate::core::ServerResponseStreamItem;
27+
use crate::server::RecvStream;
28+
use crate::server::SendOptions;
29+
use crate::server::SendStream;
30+
31+
/// Enforces proper gRPC semantics on the server sending stream.
32+
pub struct ServerSendStreamValidator<S> {
33+
inner: S,
34+
state: SendStreamState,
35+
}
36+
37+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
38+
enum SendStreamState {
39+
Init,
40+
HeadersSent,
41+
MessagesSent,
42+
Done,
43+
}
44+
45+
impl<S> ServerSendStreamValidator<S>
46+
where
47+
S: SendStream,
48+
{
49+
/// Constructs a new `ServerSendStreamValidator`.
50+
pub fn new(inner: S) -> Self {
51+
Self {
52+
inner,
53+
state: SendStreamState::Init,
54+
}
55+
}
56+
}
57+
58+
impl<S> SendStream for ServerSendStreamValidator<S>
59+
where
60+
S: SendStream,
61+
{
62+
async fn send<'a>(
63+
&mut self,
64+
item: ServerResponseStreamItem<'a>,
65+
options: SendOptions,
66+
) -> Result<(), ()> {
67+
if self.state == SendStreamState::Done {
68+
return Err(());
69+
}
70+
71+
let next_state = match &item {
72+
ServerResponseStreamItem::Headers(_) => match self.state {
73+
SendStreamState::Init => SendStreamState::HeadersSent,
74+
_ => {
75+
self.state = SendStreamState::Done;
76+
return Err(());
77+
}
78+
},
79+
ServerResponseStreamItem::Message(_) => match self.state {
80+
SendStreamState::HeadersSent | SendStreamState::MessagesSent => {
81+
SendStreamState::MessagesSent
82+
}
83+
_ => {
84+
self.state = SendStreamState::Done;
85+
return Err(());
86+
}
87+
},
88+
};
89+
90+
let res = self.inner.send(item, options).await;
91+
match res {
92+
Ok(()) => self.state = next_state,
93+
Err(_) => self.state = SendStreamState::Done,
94+
}
95+
res
96+
}
97+
}
98+
99+
/// Enforces proper gRPC semantics on the server receiving stream.
100+
pub struct ServerRecvStreamValidator<R> {
101+
inner: R,
102+
state: RecvStreamState,
103+
}
104+
105+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
106+
enum RecvStreamState {
107+
Init,
108+
Done,
109+
}
110+
111+
impl<R> ServerRecvStreamValidator<R>
112+
where
113+
R: RecvStream,
114+
{
115+
/// Constructs a new `ServerRecvStreamValidator`.
116+
pub fn new(inner: R) -> Self {
117+
Self {
118+
inner,
119+
state: RecvStreamState::Init,
120+
}
121+
}
122+
}
123+
124+
impl<R> RecvStream for ServerRecvStreamValidator<R>
125+
where
126+
R: RecvStream,
127+
{
128+
async fn next(&mut self, msg: &mut dyn RecvMessage) -> Option<Result<(), ()>> {
129+
if self.state == RecvStreamState::Done {
130+
return Some(Err(()));
131+
}
132+
133+
let res = self.inner.next(msg).await;
134+
135+
match res {
136+
Some(Ok(())) => Some(Ok(())),
137+
None => {
138+
self.state = RecvStreamState::Done;
139+
None
140+
}
141+
Some(Err(())) => {
142+
self.state = RecvStreamState::Done;
143+
Some(Err(()))
144+
}
145+
}
146+
}
147+
}
148+
149+
#[cfg(test)]
150+
mod tests {
151+
use super::*;
152+
use crate::core::ResponseHeaders;
153+
use crate::core::SendMessage;
154+
use bytes::Buf;
155+
use bytes::Bytes;
156+
157+
impl SendMessage for () {
158+
fn encode(&self) -> Result<Box<dyn Buf + Send + Sync>, String> {
159+
Ok(Box::new(Bytes::new()))
160+
}
161+
}
162+
163+
#[derive(Debug, PartialEq, Eq)]
164+
enum SendEvent {
165+
Headers,
166+
Message,
167+
}
168+
169+
struct MockSendStream {
170+
events: Vec<SendEvent>,
171+
}
172+
173+
impl MockSendStream {
174+
fn new() -> Self {
175+
Self { events: Vec::new() }
176+
}
177+
}
178+
179+
impl SendStream for MockSendStream {
180+
async fn send<'a>(
181+
&mut self,
182+
item: ServerResponseStreamItem<'a>,
183+
_options: SendOptions,
184+
) -> Result<(), ()> {
185+
match item {
186+
ServerResponseStreamItem::Headers(_) => self.events.push(SendEvent::Headers),
187+
ServerResponseStreamItem::Message(_) => self.events.push(SendEvent::Message),
188+
}
189+
Ok(())
190+
}
191+
}
192+
193+
#[tokio::test]
194+
async fn test_send_validator_valid_full_stream() {
195+
let mock = MockSendStream::new();
196+
let mut validator = ServerSendStreamValidator::new(mock);
197+
198+
assert!(
199+
validator
200+
.send(
201+
ServerResponseStreamItem::Headers(ResponseHeaders::default()),
202+
SendOptions::default()
203+
)
204+
.await
205+
.is_ok()
206+
);
207+
assert!(
208+
validator
209+
.send(
210+
ServerResponseStreamItem::Message(&()),
211+
SendOptions::default()
212+
)
213+
.await
214+
.is_ok()
215+
);
216+
217+
assert_eq!(
218+
validator.inner.events,
219+
vec![SendEvent::Headers, SendEvent::Message]
220+
);
221+
}
222+
223+
struct FailingMockSendStream;
224+
225+
impl SendStream for FailingMockSendStream {
226+
async fn send<'a>(
227+
&mut self,
228+
_item: ServerResponseStreamItem<'a>,
229+
_options: SendOptions,
230+
) -> Result<(), ()> {
231+
Err(())
232+
}
233+
}
234+
235+
#[tokio::test]
236+
async fn test_send_validator_invalid_message_before_headers() {
237+
let mock = MockSendStream::new();
238+
let mut validator = ServerSendStreamValidator::new(mock);
239+
240+
assert!(
241+
validator
242+
.send(
243+
ServerResponseStreamItem::Message(&()),
244+
SendOptions::default()
245+
)
246+
.await
247+
.is_err()
248+
);
249+
assert_eq!(validator.state, SendStreamState::Done);
250+
}
251+
252+
#[tokio::test]
253+
async fn test_send_validator_invalid_headers_twice() {
254+
let mock = MockSendStream::new();
255+
let mut validator = ServerSendStreamValidator::new(mock);
256+
257+
assert!(
258+
validator
259+
.send(
260+
ServerResponseStreamItem::Headers(ResponseHeaders::default()),
261+
SendOptions::default()
262+
)
263+
.await
264+
.is_ok()
265+
);
266+
assert!(
267+
validator
268+
.send(
269+
ServerResponseStreamItem::Headers(ResponseHeaders::default()),
270+
SendOptions::default()
271+
)
272+
.await
273+
.is_err()
274+
);
275+
}
276+
277+
#[tokio::test]
278+
async fn test_send_validator_state_transitions_to_done_on_error() {
279+
let mock = FailingMockSendStream;
280+
let mut validator = ServerSendStreamValidator::new(mock);
281+
282+
assert!(
283+
validator
284+
.send(
285+
ServerResponseStreamItem::Headers(ResponseHeaders::default()),
286+
SendOptions::default()
287+
)
288+
.await
289+
.is_err()
290+
);
291+
assert_eq!(validator.state, SendStreamState::Done);
292+
}
293+
294+
struct MockRecvStream {
295+
items: Vec<Option<Result<(), ()>>>,
296+
index: usize,
297+
}
298+
299+
impl MockRecvStream {
300+
fn new(items: Vec<Option<Result<(), ()>>>) -> Self {
301+
Self { items, index: 0 }
302+
}
303+
}
304+
305+
impl RecvStream for MockRecvStream {
306+
async fn next(&mut self, _msg: &mut dyn RecvMessage) -> Option<Result<(), ()>> {
307+
if self.index < self.items.len() {
308+
let res = self.items[self.index];
309+
self.index += 1;
310+
res
311+
} else {
312+
None
313+
}
314+
}
315+
}
316+
317+
struct NopRecvMessage;
318+
impl RecvMessage for NopRecvMessage {
319+
fn decode(&mut self, _data: &mut dyn bytes::Buf) -> Result<(), String> {
320+
Ok(())
321+
}
322+
}
323+
324+
#[tokio::test]
325+
async fn test_recv_validator_valid_unary() {
326+
let mock = MockRecvStream::new(vec![Some(Ok(())), None]);
327+
let mut validator = ServerRecvStreamValidator::new(mock);
328+
let mut msg = NopRecvMessage;
329+
330+
assert!(matches!(validator.next(&mut msg).await, Some(Ok(()))));
331+
assert!(validator.next(&mut msg).await.is_none());
332+
}
333+
334+
#[tokio::test]
335+
async fn test_recv_validator_empty_stream() {
336+
let mock = MockRecvStream::new(vec![None]);
337+
let mut validator = ServerRecvStreamValidator::new(mock);
338+
let mut msg = NopRecvMessage;
339+
340+
assert!(validator.next(&mut msg).await.is_none());
341+
}
342+
343+
#[tokio::test]
344+
async fn test_recv_validator_error_after_done() {
345+
let mock = MockRecvStream::new(vec![None]);
346+
let mut validator = ServerRecvStreamValidator::new(mock);
347+
let mut msg = NopRecvMessage;
348+
349+
assert!(validator.next(&mut msg).await.is_none());
350+
assert!(matches!(validator.next(&mut msg).await, Some(Err(()))));
351+
}
352+
353+
#[tokio::test]
354+
async fn test_recv_validator_valid_streaming() {
355+
let mock = MockRecvStream::new(vec![Some(Ok(())), Some(Ok(())), None]);
356+
let mut validator = ServerRecvStreamValidator::new(mock);
357+
let mut msg = NopRecvMessage;
358+
359+
assert!(matches!(validator.next(&mut msg).await, Some(Ok(()))));
360+
assert!(matches!(validator.next(&mut msg).await, Some(Ok(()))));
361+
assert!(validator.next(&mut msg).await.is_none());
362+
}
363+
364+
#[tokio::test]
365+
async fn test_recv_validator_terminal_error() {
366+
let mock = MockRecvStream::new(vec![Some(Err(()))]);
367+
let mut validator = ServerRecvStreamValidator::new(mock);
368+
let mut msg = NopRecvMessage;
369+
370+
assert!(matches!(validator.next(&mut msg).await, Some(Err(()))));
371+
372+
// Further calls should return the same error
373+
assert!(matches!(validator.next(&mut msg).await, Some(Err(()))));
374+
}
375+
}

0 commit comments

Comments
 (0)