@@ -6,13 +6,19 @@ use quinn::{Connection, Endpoint};
66use std:: net:: { SocketAddr , SocketAddrV4 , SocketAddrV6 } ;
77use std:: ops:: Deref ;
88use std:: sync:: Arc ;
9+ use std:: sync:: atomic:: { AtomicBool , Ordering } ;
910use tracing:: { debug, instrument, warn} ;
1011use url:: Host ;
1112
1213#[ derive( Clone ) ]
1314pub struct QuicConnection {
15+ inner : Arc < QuicConnectionInner > ,
16+ }
17+
18+ pub struct QuicConnectionInner {
1419 config : Arc < WsClientConfig > ,
1520 endpoint : Endpoint ,
21+ is_broken : AtomicBool ,
1622}
1723
1824impl QuicConnection {
@@ -55,15 +61,21 @@ impl QuicConnection {
5561 let endpoint = Endpoint :: new ( quinn:: EndpointConfig :: default ( ) , None , socket, Arc :: new ( quinn:: TokioRuntime ) )
5662 . expect ( "Failed to create QUIC endpoint" ) ;
5763
58- Self { config, endpoint }
64+ Self {
65+ inner : Arc :: new ( QuicConnectionInner {
66+ config,
67+ endpoint,
68+ is_broken : AtomicBool :: new ( false ) ,
69+ } ) ,
70+ }
5971 }
6072}
6173
6274impl Deref for QuicConnection {
63- type Target = WsClientConfig ;
75+ type Target = QuicConnectionInner ;
6476
6577 fn deref ( & self ) -> & Self :: Target {
66- & self . config
78+ & self . inner
6779 }
6880}
6981
@@ -74,12 +86,15 @@ impl ManageConnection for QuicConnection {
7486 #[ instrument( level = "trace" , name = "quic_cnx_server" , skip_all) ]
7587 async fn connect ( & self ) -> Result < Self :: Connection , Self :: Error > {
7688 // 1. Resolve DNS
77- let host = self . remote_addr . host ( ) ;
78- let port = self . remote_addr . port ( ) ;
89+ self . inner . is_broken . store ( false , Ordering :: SeqCst ) ;
90+ let host = self . inner . config . remote_addr . host ( ) ;
91+ let port = self . inner . config . remote_addr . port ( ) ;
7992
8093 let remote_addr = match host {
8194 Host :: Domain ( domain) => {
8295 let addrs = self
96+ . inner
97+ . config
8398 . dns_resolver
8499 . lookup_host ( domain, port)
85100 . await
@@ -95,6 +110,8 @@ impl ManageConnection for QuicConnection {
95110
96111 // 2. Get TLS configuration
97112 let tls_config = self
113+ . inner
114+ . config
98115 . remote_addr
99116 . tls ( )
100117 . ok_or_else ( || anyhow ! ( "QUIC requires TLS configuration" ) ) ?;
@@ -111,7 +128,7 @@ impl ManageConnection for QuicConnection {
111128 debug ! (
112129 "Creating QUIC client config for {} (SNI: {:?}), mTLS: {}" ,
113130 remote_addr,
114- self . tls_server_name( ) ,
131+ self . inner . config . tls_server_name( ) ,
115132 tls_client_certificate. is_some( )
116133 ) ;
117134
@@ -134,6 +151,8 @@ impl ManageConnection for QuicConnection {
134151 // Configure max idle timeout
135152 // Use 10 minutes by default to support long-lived reverse tunnels and file transfers
136153 let idle_timeout = self
154+ . inner
155+ . config
137156 . quic_max_idle_timeout
138157 . unwrap_or ( std:: time:: Duration :: from_secs ( 600 ) ) ;
139158 debug ! ( "QUIC idle timeout: {}s" , idle_timeout. as_secs( ) ) ;
@@ -142,13 +161,19 @@ impl ManageConnection for QuicConnection {
142161 ) ) ) ;
143162
144163 // Configure keep-alive interval
145- debug ! ( "QUIC keep-alive interval: {}s" , self . quic_keep_alive_interval. as_secs( ) ) ;
146- transport_config. keep_alive_interval ( Some ( self . quic_keep_alive_interval ) ) ;
164+ debug ! (
165+ "QUIC keep-alive interval: {}s" ,
166+ self . inner. config. quic_keep_alive_interval. as_secs( )
167+ ) ;
168+ transport_config. keep_alive_interval ( Some ( self . inner . config . quic_keep_alive_interval ) ) ;
147169
148170 // Configure stream limits
149- debug ! ( "QUIC concurrent streams: {} bidirectional" , self . quic_max_concurrent_bi_streams) ;
171+ debug ! (
172+ "QUIC concurrent streams: {} bidirectional" ,
173+ self . inner. config. quic_max_concurrent_bi_streams
174+ ) ;
150175 transport_config. max_concurrent_bidi_streams (
151- quinn:: VarInt :: from_u64 ( self . quic_max_concurrent_bi_streams )
176+ quinn:: VarInt :: from_u64 ( self . inner . config . quic_max_concurrent_bi_streams )
152177 . expect ( "QUIC concurrent bidirectional streams limit too large" ) ,
153178 ) ;
154179 transport_config. max_concurrent_uni_streams ( 0u32 . into ( ) ) ; // We don't use unidirectional streams
@@ -157,20 +182,21 @@ impl ManageConnection for QuicConnection {
157182 // Connection-level flow control (total data across all streams)
158183 debug ! (
159184 "QUIC flow control - connection: {} bytes, stream: {} bytes" ,
160- self . quic_initial_max_data, self . quic_initial_max_stream_data
185+ self . inner . config . quic_initial_max_data, self . inner . config . quic_initial_max_stream_data
161186 ) ;
162187 transport_config. receive_window (
163- quinn:: VarInt :: from_u64 ( self . quic_initial_max_data ) . expect ( "QUIC initial max data limit too large" ) ,
188+ quinn:: VarInt :: from_u64 ( self . inner . config . quic_initial_max_data )
189+ . expect ( "QUIC initial max data limit too large" ) ,
164190 ) ;
165- transport_config. send_window ( self . quic_initial_max_data ) ;
191+ transport_config. send_window ( self . inner . config . quic_initial_max_data ) ;
166192
167193 // Per-stream flow control
168194 transport_config. stream_receive_window (
169- quinn:: VarInt :: from_u64 ( self . quic_initial_max_stream_data )
195+ quinn:: VarInt :: from_u64 ( self . inner . config . quic_initial_max_stream_data )
170196 . expect ( "QUIC initial max stream data limit too large" ) ,
171197 ) ;
172198
173- if let Some ( mtu) = self . quic_initial_mtu {
199+ if let Some ( mtu) = self . inner . config . quic_initial_mtu {
174200 transport_config. initial_mtu ( mtu) ;
175201 }
176202
@@ -180,11 +206,13 @@ impl ManageConnection for QuicConnection {
180206 debug ! (
181207 "Initiating QUIC connection to {} (SNI: {:?})" ,
182208 remote_addr,
183- self . tls_server_name( )
209+ self . inner . config . tls_server_name( )
184210 ) ;
185- let connecting =
186- self . endpoint
187- . connect_with ( client_config, remote_addr, self . tls_server_name ( ) . to_str ( ) . as_ref ( ) ) ?;
211+ let connecting = self . endpoint . connect_with (
212+ client_config,
213+ remote_addr,
214+ self . inner . config . tls_server_name ( ) . to_str ( ) . as_ref ( ) ,
215+ ) ?;
188216
189217 debug ! ( "Waiting for QUIC handshake to complete..." ) ;
190218 let connection = match connecting. await {
@@ -223,9 +251,20 @@ impl ManageConnection for QuicConnection {
223251 }
224252
225253 fn has_broken ( & self , conn : & mut Self :: Connection ) -> bool {
254+ if self . inner . is_broken . load ( Ordering :: SeqCst ) {
255+ warn ! ( "Connection pool: Connection marked as broken, discarding" ) ;
256+ return true ;
257+ }
258+
226259 match conn {
227- Some ( c) => c. close_reason ( ) . is_some ( ) ,
228- None => true ,
260+ Some ( c) => {
261+ if c. close_reason ( ) . is_some ( ) {
262+ warn ! ( "Connection pool: Connection has close_reason, discarding" ) ;
263+ return true ;
264+ }
265+ false
266+ }
267+ None => true , // No connection, so it's "broken"
229268 }
230269 }
231270}
0 commit comments