@@ -55,7 +55,7 @@ func TestW3CPropagation(t *testing.T) {
5555 receivedRequest := microservice .IncomingRequests ()[0 ]
5656 receivedHeaders := receivedRequest .Header
5757
58- r .Equal (receivedHeaders .Get ("Traceparent" ), traceParent )
58+ r .Contains (receivedHeaders .Get ("Traceparent" ), fullW3CLengthTraceID )
5959
6060 r .NotContains (receivedHeaders , "B3" )
6161 r .NotContains (receivedHeaders , "b3" )
@@ -67,9 +67,7 @@ func TestW3CPropagation(t *testing.T) {
6767 _ = tracerProvider .ForceFlush (request .Context ())
6868
6969 exportedSpans := exporter .GetSpans ()
70- if len (exportedSpans ) != 1 {
71- t .Fatalf ("Expected 1 Span, got %d" , len (exportedSpans ))
72- }
70+ r .GreaterOrEqual (len (exportedSpans ), 1 , "expected at least 1 span" )
7371 sc := exportedSpans [0 ].SpanContext
7472 r .Equal (fullW3CLengthTraceID , sc .TraceID ().String ())
7573 r .True (sc .IsSampled ())
@@ -99,7 +97,7 @@ func TestPropagationWhenNoHeaders(t *testing.T) {
9997 receivedRequest := microservice .IncomingRequests ()[0 ]
10098 receivedHeaders := receivedRequest .Header
10199
102- r .NotContains (receivedHeaders , "Traceparent" )
100+ r .Contains (receivedHeaders , "Traceparent" )
103101 r .NotContains (receivedHeaders , "B3" )
104102 r .NotContains (receivedHeaders , "b3" )
105103 r .NotContains (receivedHeaders , "X-B3-Parentspanid" )
@@ -110,22 +108,21 @@ func TestPropagationWhenNoHeaders(t *testing.T) {
110108 _ = tracerProvider .ForceFlush (request .Context ())
111109
112110 exportedSpans := exporter .GetSpans ()
113- if len (exportedSpans ) != 1 {
114- t .Fatalf ("Expected 1 Span, got %d" , len (exportedSpans ))
115- }
111+ r .GreaterOrEqual (len (exportedSpans ), 1 , "expected at least 1 span" )
116112 sc := exportedSpans [0 ].SpanContext
117113 r .NotEmpty (sc .SpanID ())
118114 r .NotEmpty (sc .TraceID ())
119115
120116 hasServiceAttribute := false
121117 hasColdStartAttribute := false
122- for _ , attribute := range exportedSpans [0 ].Attributes {
123- if attribute .Key == "service" && attribute .Value .AsString () == "keda-http-interceptor-proxy-upstream" {
124- hasServiceAttribute = true
125- }
126-
127- if attribute .Key == "cold-start" {
128- hasColdStartAttribute = true
118+ for _ , s := range exportedSpans {
119+ for _ , attribute := range s .Attributes {
120+ if attribute .Key == "service" && attribute .Value .AsString () == "keda-http-interceptor-proxy-upstream" {
121+ hasServiceAttribute = true
122+ }
123+ if attribute .Key == "cold-start" {
124+ hasColdStartAttribute = true
125+ }
129126 }
130127 }
131128 r .True (hasServiceAttribute )
@@ -158,8 +155,7 @@ func TestForwarderSuccess(t *testing.T) {
158155 req = util .RequestWithUpstreamURL (req , forwardURL )
159156 timeouts := defaultTimeouts ()
160157 dialCtxFunc := retryDialContextFunc (timeouts )
161- rt := newRoundTripper (dialCtxFunc , timeouts .ResponseHeader )
162- uh := NewUpstream (rt , config.Tracing {}, false )
158+ uh := NewUpstream (newTestTransport (dialCtxFunc ), config.Tracing {}, timeouts .ResponseHeader )
163159 uh .ServeHTTP (res , req )
164160
165161 r .True (
@@ -202,8 +198,7 @@ func TestForwarderHeaderTimeout(t *testing.T) {
202198 res , req , err := reqAndRes ("/testfwd" )
203199 r .NoError (err )
204200 req = util .RequestWithUpstreamURL (req , originURL )
205- rt := newRoundTripper (dialCtxFunc , timeouts .ResponseHeader )
206- uh := NewUpstream (rt , config.Tracing {}, false )
201+ uh := NewUpstream (newTestTransport (dialCtxFunc ), config.Tracing {}, timeouts .ResponseHeader )
207202 uh .ServeHTTP (res , req )
208203
209204 forwardedRequests := hdl .IncomingRequests ()
@@ -252,8 +247,7 @@ func TestForwarderWaitsForSlowOrigin(t *testing.T) {
252247 res , req , err := reqAndRes (path )
253248 r .NoError (err )
254249 req = util .RequestWithUpstreamURL (req , originURL )
255- rt := newRoundTripper (dialCtxFunc , timeouts .ResponseHeader )
256- uh := NewUpstream (rt , config.Tracing {}, false )
250+ uh := NewUpstream (newTestTransport (dialCtxFunc ), config.Tracing {}, timeouts .ResponseHeader )
257251 uh .ServeHTTP (res , req )
258252 // wait for the goroutine above to finish, with a little cusion
259253 ensureSignalBeforeTimeout (originWaitCh , originDelay * 2 )
@@ -272,8 +266,7 @@ func TestForwarderConnectionRetryAndTimeout(t *testing.T) {
272266 res , req , err := reqAndRes ("/test" )
273267 r .NoError (err )
274268 req = util .RequestWithUpstreamURL (req , noSuchURL )
275- rt := newRoundTripper (dialCtxFunc , timeouts .ResponseHeader )
276- uh := NewUpstream (rt , config.Tracing {}, false )
269+ uh := NewUpstream (newTestTransport (dialCtxFunc ), config.Tracing {}, timeouts .ResponseHeader )
277270
278271 start := time .Now ()
279272 uh .ServeHTTP (res , req )
@@ -321,8 +314,7 @@ func TestForwardRequestRedirectAndHeaders(t *testing.T) {
321314 res , req , err := reqAndRes ("/testfwd" )
322315 r .NoError (err )
323316 req = util .RequestWithUpstreamURL (req , srvURL )
324- rt := newRoundTripper (dialCtxFunc , timeouts .ResponseHeader )
325- uh := NewUpstream (rt , config.Tracing {}, false )
317+ uh := NewUpstream (newTestTransport (dialCtxFunc ), config.Tracing {}, timeouts .ResponseHeader )
326318 uh .ServeHTTP (res , req )
327319 r .Equal (301 , res .Code )
328320 r .Equal ("abc123.com" , res .Header ().Get ("Location" ))
@@ -379,7 +371,7 @@ func TestUpstreamPreservesXForwardedHeaders(t *testing.T) {
379371 }
380372
381373 // Configure the Upstream and send a dummy request
382- upstream := NewUpstream (http .DefaultTransport , config.Tracing {}, false )
374+ upstream := NewUpstream (http .DefaultTransport .( * http. Transport ) , config.Tracing {}, 500 * time . Millisecond )
383375
384376 req := httptest .NewRequest ("GET" , "/test" , nil )
385377 if tt .forwardedFor != "" {
@@ -435,13 +427,9 @@ func TestUpstreamPreservesXForwardedHeaders(t *testing.T) {
435427 }
436428}
437429
438- func newRoundTripper (
439- dialCtxFunc kedanet.DialContextFunc ,
440- httpRespHeaderTimeout time.Duration ,
441- ) http.RoundTripper {
430+ func newTestTransport (dialCtxFunc kedanet.DialContextFunc ) * http.Transport {
442431 return & http.Transport {
443- DialContext : dialCtxFunc ,
444- ResponseHeaderTimeout : httpRespHeaderTimeout ,
432+ DialContext : dialCtxFunc ,
445433 }
446434}
447435
@@ -485,8 +473,8 @@ func ensureSignalBeforeTimeout(signalCh <-chan struct{}, timeout time.Duration)
485473func serveHTTP (w http.ResponseWriter , r * http.Request ) {
486474 timeouts := defaultTimeouts ()
487475 dialCtxFunc := retryDialContextFunc (timeouts )
488- rt := newRoundTripper (dialCtxFunc , timeouts . ResponseHeader )
489- upstream := NewUpstream (rt , config.Tracing {Enabled : true }, false )
476+ transport := newTestTransport (dialCtxFunc )
477+ upstream := NewUpstream (transport , config.Tracing {Enabled : true }, timeouts . ResponseHeader )
490478
491479 upstream .ServeHTTP (w , r )
492480}
0 commit comments