Skip to content

Commit d097bca

Browse files
authored
refactor: extract endpoint resolution into dedicated middleware (#1558)
1 parent 3394146 commit d097bca

19 files changed

Lines changed: 553 additions & 1119 deletions

interceptor/forward_wait_func.go

Lines changed: 0 additions & 19 deletions
This file was deleted.

interceptor/handler/upstream.go

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ import (
44
"errors"
55
"net/http"
66
"net/http/httputil"
7+
"time"
78

8-
"go.opentelemetry.io/otel"
9+
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
910
"go.opentelemetry.io/otel/attribute"
10-
"go.opentelemetry.io/otel/propagation"
1111
"go.opentelemetry.io/otel/trace"
1212

1313
"github.com/kedacore/http-add-on/interceptor/config"
14+
kedahttp "github.com/kedacore/http-add-on/pkg/http"
15+
"github.com/kedacore/http-add-on/pkg/k8s"
1416
"github.com/kedacore/http-add-on/pkg/util"
1517
)
1618

@@ -21,16 +23,16 @@ var (
2123
)
2224

2325
type Upstream struct {
24-
roundTripper http.RoundTripper
25-
tracingCfg config.Tracing
26-
shouldFallback bool
26+
transportPool *kedahttp.TransportPool
27+
tracingCfg config.Tracing
28+
respHeaderTimeout time.Duration
2729
}
2830

29-
func NewUpstream(roundTripper http.RoundTripper, tracingCfg config.Tracing, shouldFallback bool) *Upstream {
31+
func NewUpstream(baseTransport *http.Transport, tracingCfg config.Tracing, respHeaderTimeout time.Duration) *Upstream {
3032
return &Upstream{
31-
roundTripper: roundTripper,
32-
tracingCfg: tracingCfg,
33-
shouldFallback: shouldFallback,
33+
transportPool: kedahttp.NewTransportPool(baseTransport),
34+
tracingCfg: tracingCfg,
35+
respHeaderTimeout: respHeaderTimeout,
3436
}
3537
}
3638

@@ -41,24 +43,14 @@ func (uh *Upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4143
ctx := r.Context()
4244

4345
if uh.tracingCfg.Enabled {
44-
p := otel.GetTextMapPropagator()
45-
ctx = p.Extract(ctx, propagation.HeaderCarrier(r.Header))
46-
47-
p.Inject(ctx, propagation.HeaderCarrier(w.Header()))
48-
4946
span := trace.SpanFromContext(ctx)
50-
defer span.End()
51-
52-
serviceValAttr := attribute.String("service", "keda-http-interceptor-proxy-upstream")
53-
coldStartValAttr := attribute.String("cold-start", w.Header().Get("X-KEDA-HTTP-Cold-Start"))
54-
55-
span.SetAttributes(serviceValAttr, coldStartValAttr)
47+
span.SetAttributes(
48+
attribute.String("service", "keda-http-interceptor-proxy-upstream"),
49+
attribute.String("cold-start", w.Header().Get(kedahttp.HeaderColdStart)),
50+
)
5651
}
5752

5853
url := util.UpstreamURLFromContext(ctx)
59-
if uh.shouldFallback {
60-
url = util.FallbackURLFromContext(ctx)
61-
}
6254

6355
if url == nil {
6456
sh := NewStatic(http.StatusInternalServerError, errNilUpstreamURL)
@@ -67,6 +59,27 @@ func (uh *Upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
6759
return
6860
}
6961

62+
respHeaderTimeout := uh.respHeaderTimeout
63+
// TODO(v1): remove timeout compatibility fallback for HTTPSO before v1 release
64+
if ir := util.InterceptorRouteFromContext(ctx); ir != nil {
65+
if v, ok := ir.Annotations[k8s.AnnotationResponseHeaderTimeout]; ok {
66+
if d, err := time.ParseDuration(v); err == nil && d > 0 {
67+
respHeaderTimeout = d
68+
}
69+
}
70+
}
71+
transport := uh.transportPool.Get(respHeaderTimeout)
72+
73+
var rt http.RoundTripper = transport
74+
if uh.tracingCfg.Enabled {
75+
rt = otelhttp.NewTransport(transport)
76+
}
77+
78+
rc := http.NewResponseController(w)
79+
if err := rc.EnableFullDuplex(); err != nil {
80+
util.LoggerFromContext(ctx).Error(err, "could not enable full duplex on response writer, continuing")
81+
}
82+
7083
proxy := &httputil.ReverseProxy{
7184
Rewrite: func(pr *httputil.ProxyRequest) {
7285
pr.SetURL(url)
@@ -84,7 +97,7 @@ func (uh *Upstream) ServeHTTP(w http.ResponseWriter, r *http.Request) {
8497
}
8598
},
8699
BufferPool: bufferPool,
87-
Transport: uh.roundTripper,
100+
Transport: rt,
88101
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
89102
sh := NewStatic(http.StatusBadGateway, err)
90103
sh.ServeHTTP(w, r)

interceptor/handler/upstream_test.go

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func TestW3CPropagation(t *testing.T) {
5555
receivedRequest := microservice.IncomingRequests()[0]
5656
receivedHeaders := receivedRequest.Header
5757

58-
r.Equal(receivedHeaders.Get("Traceparent"), traceParent)
58+
r.Contains(receivedHeaders.Get("Traceparent"), fullW3CLengthTraceID)
5959

6060
r.NotContains(receivedHeaders, "B3")
6161
r.NotContains(receivedHeaders, "b3")
@@ -67,9 +67,7 @@ func TestW3CPropagation(t *testing.T) {
6767
_ = tracerProvider.ForceFlush(request.Context())
6868

6969
exportedSpans := exporter.GetSpans()
70-
if len(exportedSpans) != 1 {
71-
t.Fatalf("Expected 1 Span, got %d", len(exportedSpans))
72-
}
70+
r.GreaterOrEqual(len(exportedSpans), 1, "expected at least 1 span")
7371
sc := exportedSpans[0].SpanContext
7472
r.Equal(fullW3CLengthTraceID, sc.TraceID().String())
7573
r.True(sc.IsSampled())
@@ -99,7 +97,7 @@ func TestPropagationWhenNoHeaders(t *testing.T) {
9997
receivedRequest := microservice.IncomingRequests()[0]
10098
receivedHeaders := receivedRequest.Header
10199

102-
r.NotContains(receivedHeaders, "Traceparent")
100+
r.Contains(receivedHeaders, "Traceparent")
103101
r.NotContains(receivedHeaders, "B3")
104102
r.NotContains(receivedHeaders, "b3")
105103
r.NotContains(receivedHeaders, "X-B3-Parentspanid")
@@ -110,22 +108,21 @@ func TestPropagationWhenNoHeaders(t *testing.T) {
110108
_ = tracerProvider.ForceFlush(request.Context())
111109

112110
exportedSpans := exporter.GetSpans()
113-
if len(exportedSpans) != 1 {
114-
t.Fatalf("Expected 1 Span, got %d", len(exportedSpans))
115-
}
111+
r.GreaterOrEqual(len(exportedSpans), 1, "expected at least 1 span")
116112
sc := exportedSpans[0].SpanContext
117113
r.NotEmpty(sc.SpanID())
118114
r.NotEmpty(sc.TraceID())
119115

120116
hasServiceAttribute := false
121117
hasColdStartAttribute := false
122-
for _, attribute := range exportedSpans[0].Attributes {
123-
if attribute.Key == "service" && attribute.Value.AsString() == "keda-http-interceptor-proxy-upstream" {
124-
hasServiceAttribute = true
125-
}
126-
127-
if attribute.Key == "cold-start" {
128-
hasColdStartAttribute = true
118+
for _, s := range exportedSpans {
119+
for _, attribute := range s.Attributes {
120+
if attribute.Key == "service" && attribute.Value.AsString() == "keda-http-interceptor-proxy-upstream" {
121+
hasServiceAttribute = true
122+
}
123+
if attribute.Key == "cold-start" {
124+
hasColdStartAttribute = true
125+
}
129126
}
130127
}
131128
r.True(hasServiceAttribute)
@@ -158,8 +155,7 @@ func TestForwarderSuccess(t *testing.T) {
158155
req = util.RequestWithUpstreamURL(req, forwardURL)
159156
timeouts := defaultTimeouts()
160157
dialCtxFunc := retryDialContextFunc(timeouts)
161-
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
162-
uh := NewUpstream(rt, config.Tracing{}, false)
158+
uh := NewUpstream(newTestTransport(dialCtxFunc), config.Tracing{}, timeouts.ResponseHeader)
163159
uh.ServeHTTP(res, req)
164160

165161
r.True(
@@ -202,8 +198,7 @@ func TestForwarderHeaderTimeout(t *testing.T) {
202198
res, req, err := reqAndRes("/testfwd")
203199
r.NoError(err)
204200
req = util.RequestWithUpstreamURL(req, originURL)
205-
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
206-
uh := NewUpstream(rt, config.Tracing{}, false)
201+
uh := NewUpstream(newTestTransport(dialCtxFunc), config.Tracing{}, timeouts.ResponseHeader)
207202
uh.ServeHTTP(res, req)
208203

209204
forwardedRequests := hdl.IncomingRequests()
@@ -252,8 +247,7 @@ func TestForwarderWaitsForSlowOrigin(t *testing.T) {
252247
res, req, err := reqAndRes(path)
253248
r.NoError(err)
254249
req = util.RequestWithUpstreamURL(req, originURL)
255-
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
256-
uh := NewUpstream(rt, config.Tracing{}, false)
250+
uh := NewUpstream(newTestTransport(dialCtxFunc), config.Tracing{}, timeouts.ResponseHeader)
257251
uh.ServeHTTP(res, req)
258252
// wait for the goroutine above to finish, with a little cusion
259253
ensureSignalBeforeTimeout(originWaitCh, originDelay*2)
@@ -272,8 +266,7 @@ func TestForwarderConnectionRetryAndTimeout(t *testing.T) {
272266
res, req, err := reqAndRes("/test")
273267
r.NoError(err)
274268
req = util.RequestWithUpstreamURL(req, noSuchURL)
275-
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
276-
uh := NewUpstream(rt, config.Tracing{}, false)
269+
uh := NewUpstream(newTestTransport(dialCtxFunc), config.Tracing{}, timeouts.ResponseHeader)
277270

278271
start := time.Now()
279272
uh.ServeHTTP(res, req)
@@ -321,8 +314,7 @@ func TestForwardRequestRedirectAndHeaders(t *testing.T) {
321314
res, req, err := reqAndRes("/testfwd")
322315
r.NoError(err)
323316
req = util.RequestWithUpstreamURL(req, srvURL)
324-
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
325-
uh := NewUpstream(rt, config.Tracing{}, false)
317+
uh := NewUpstream(newTestTransport(dialCtxFunc), config.Tracing{}, timeouts.ResponseHeader)
326318
uh.ServeHTTP(res, req)
327319
r.Equal(301, res.Code)
328320
r.Equal("abc123.com", res.Header().Get("Location"))
@@ -379,7 +371,7 @@ func TestUpstreamPreservesXForwardedHeaders(t *testing.T) {
379371
}
380372

381373
// Configure the Upstream and send a dummy request
382-
upstream := NewUpstream(http.DefaultTransport, config.Tracing{}, false)
374+
upstream := NewUpstream(http.DefaultTransport.(*http.Transport), config.Tracing{}, 500*time.Millisecond)
383375

384376
req := httptest.NewRequest("GET", "/test", nil)
385377
if tt.forwardedFor != "" {
@@ -435,13 +427,9 @@ func TestUpstreamPreservesXForwardedHeaders(t *testing.T) {
435427
}
436428
}
437429

438-
func newRoundTripper(
439-
dialCtxFunc kedanet.DialContextFunc,
440-
httpRespHeaderTimeout time.Duration,
441-
) http.RoundTripper {
430+
func newTestTransport(dialCtxFunc kedanet.DialContextFunc) *http.Transport {
442431
return &http.Transport{
443-
DialContext: dialCtxFunc,
444-
ResponseHeaderTimeout: httpRespHeaderTimeout,
432+
DialContext: dialCtxFunc,
445433
}
446434
}
447435

@@ -485,8 +473,8 @@ func ensureSignalBeforeTimeout(signalCh <-chan struct{}, timeout time.Duration)
485473
func serveHTTP(w http.ResponseWriter, r *http.Request) {
486474
timeouts := defaultTimeouts()
487475
dialCtxFunc := retryDialContextFunc(timeouts)
488-
rt := newRoundTripper(dialCtxFunc, timeouts.ResponseHeader)
489-
upstream := NewUpstream(rt, config.Tracing{Enabled: true}, false)
476+
transport := newTestTransport(dialCtxFunc)
477+
upstream := NewUpstream(transport, config.Tracing{Enabled: true}, timeouts.ResponseHeader)
490478

491479
upstream.ServeHTTP(w, r)
492480
}

interceptor/main.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ func main() {
101101
runtime.Goexit()
102102
}
103103

104-
waitFunc := newWorkloadReplicasForwardWaitFunc(readyCache)
105-
106104
queues := queue.NewMemory()
107105
routingTable := routing.NewTable(ctrlCache, queues)
108106

@@ -218,7 +216,7 @@ func main() {
218216

219217
setupLog.Info("starting the proxy server with TLS enabled", "port", proxyTLSPort)
220218

221-
if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, ctrlCache, timeoutCfg, servingCfg, proxyTLSPort, tlsCfg, tracingCfg); !util.IsIgnoredErr(err) {
219+
if err := runProxyServer(ctx, ctrl.Log, queues, readyCache, routingTable, ctrlCache, timeoutCfg, servingCfg, proxyTLSPort, tlsCfg, tracingCfg); !util.IsIgnoredErr(err) {
222220
setupLog.Error(err, "tls proxy server failed")
223221
return err
224222
}
@@ -230,7 +228,7 @@ func main() {
230228
eg.Go(func() error {
231229
setupLog.Info("starting the proxy server with TLS disabled", "port", servingCfg.ProxyPort)
232230

233-
if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, ctrlCache, timeoutCfg, servingCfg, servingCfg.ProxyPort, nil, tracingCfg); !util.IsIgnoredErr(err) {
231+
if err := runProxyServer(ctx, ctrl.Log, queues, readyCache, routingTable, ctrlCache, timeoutCfg, servingCfg, servingCfg.ProxyPort, nil, tracingCfg); !util.IsIgnoredErr(err) {
234232
setupLog.Error(err, "proxy server failed")
235233
return err
236234
}
@@ -292,7 +290,7 @@ func runProxyServer(
292290
ctx context.Context,
293291
logger logr.Logger,
294292
q queue.Counter,
295-
waitFunc forwardWaitFunc,
293+
readyCache *k8s.ReadyEndpointsCache,
296294
routingTable routing.Table,
297295
reader client.Reader,
298296
timeouts config.Timeouts,
@@ -305,7 +303,7 @@ func runProxyServer(
305303
rootHandler := BuildProxyHandler(&ProxyHandlerConfig{
306304
Logger: logger,
307305
Queue: q,
308-
WaitFunc: waitFunc,
306+
ReadyCache: readyCache,
309307
RoutingTable: routingTable,
310308
Reader: reader,
311309
Timeouts: timeouts,

interceptor/middleware/counting.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@ import (
1010
)
1111

1212
type Counting struct {
13-
queueCounter queue.Counter
14-
upstreamHandler http.Handler
13+
next http.Handler
14+
queueCounter queue.Counter
1515
}
1616

17-
func NewCountingMiddleware(queueCounter queue.Counter, upstreamHandler http.Handler) *Counting {
17+
func NewCountingMiddleware(next http.Handler, queueCounter queue.Counter) *Counting {
1818
return &Counting{
19-
queueCounter: queueCounter,
20-
upstreamHandler: upstreamHandler,
19+
next: next,
20+
queueCounter: queueCounter,
2121
}
2222
}
2323

@@ -32,7 +32,7 @@ func (cm *Counting) ServeHTTP(w http.ResponseWriter, r *http.Request) {
3232

3333
if err := cm.queueCounter.Increase(key, 1); err != nil {
3434
util.LoggerFromContext(ctx).Error(err, "error incrementing queue counter", "key", key)
35-
cm.upstreamHandler.ServeHTTP(w, r)
35+
cm.next.ServeHTTP(w, r)
3636
return
3737
}
3838
metrics.RecordPendingRequestCount(key, 1)
@@ -44,5 +44,5 @@ func (cm *Counting) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4444
metrics.RecordPendingRequestCount(key, -1)
4545
}()
4646

47-
cm.upstreamHandler.ServeHTTP(w, r)
47+
cm.next.ServeHTTP(w, r)
4848
}

interceptor/middleware/counting_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ func TestCountMiddleware(t *testing.T) {
3939

4040
var concurrencyDuringRequest int
4141
middleware := NewCountingMiddleware(
42-
queueCounter,
4342
http.HandlerFunc(func(wr http.ResponseWriter, _ *http.Request) {
4443
counts, err := queueCounter.Current()
4544
if err == nil {
@@ -48,6 +47,7 @@ func TestCountMiddleware(t *testing.T) {
4847
wr.WriteHeader(200)
4948
_, _ = wr.Write([]byte("OK"))
5049
}),
50+
queueCounter,
5151
)
5252

5353
req, err := http.NewRequest("GET", "/something", nil)

0 commit comments

Comments
 (0)