Skip to content

Commit 2c39823

Browse files
otel/xxray: add custom X-Ray propagator that preserves X-Amzn-Trace-Id
Updates #3663 Signed-off-by: Alexander Yastrebov <alexander.yastrebov@zalando.de>
1 parent 859fe9b commit 2c39823

2 files changed

Lines changed: 185 additions & 5 deletions

File tree

otel/otel.go

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ import (
88
"errors"
99
"fmt"
1010
"os"
11+
"slices"
12+
"strings"
1113
"time"
1214

1315
"go.opentelemetry.io/contrib/exporters/autoexport"
1416
"go.opentelemetry.io/contrib/propagators/autoprop"
17+
"go.opentelemetry.io/contrib/propagators/aws/xray"
1518
"go.opentelemetry.io/otel"
1619
"go.opentelemetry.io/otel/attribute"
1720
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
@@ -23,6 +26,7 @@ import (
2326

2427
"github.com/bombsimon/logrusr/v4"
2528
"github.com/sirupsen/logrus"
29+
"github.com/zalando/skipper/otel/xxray"
2630
)
2731

2832
var log = logrus.WithField("package", "otel")
@@ -53,6 +57,10 @@ type BatchSpanProcessor struct {
5357
MaxExportBatchSize int `yaml:"maxExportBatchSize"`
5458
}
5559

60+
func init() {
61+
autoprop.RegisterTextMapPropagator("xxray", xxray.NewPropagator())
62+
}
63+
5664
// Init bootstraps the OpenTelemetry pipeline using environment variables and provided options.
5765
// Make sure to call shutdown for proper cleanup if err is nil.
5866
//
@@ -129,17 +137,22 @@ func Init(ctx context.Context, o *Options) (shutdown func(context.Context) error
129137
return handleErr(err)
130138
}
131139

132-
tracerProvider := trace.NewTracerProvider(batcherOpt, resourceOpt)
133-
shutdownFuncs = append(shutdownFuncs, tracerProvider.Shutdown)
134-
135-
otel.SetTracerProvider(tracerProvider)
136-
137140
propagator, err := textMapPropagator(o)
138141
if err != nil {
139142
return handleErr(err)
140143
}
141144
otel.SetTextMapPropagator(propagator)
142145

146+
var idGenerator trace.IDGenerator
147+
if hasPropagator("xray", o) || hasPropagator("xxray", o) {
148+
idGenerator = xray.NewIDGenerator()
149+
}
150+
151+
tracerProvider := trace.NewTracerProvider(batcherOpt, resourceOpt, trace.WithIDGenerator(idGenerator))
152+
shutdownFuncs = append(shutdownFuncs, tracerProvider.Shutdown)
153+
154+
otel.SetTracerProvider(tracerProvider)
155+
143156
otel.SetErrorHandler(otel.ErrorHandlerFunc(func(err error) { log.Error(err) }))
144157
otel.SetLogger(logrusr.New(log))
145158

@@ -246,6 +259,14 @@ func textMapPropagator(o *Options) (propagation.TextMapPropagator, error) {
246259
}
247260
}
248261

262+
func hasPropagator(name string, o *Options) bool {
263+
if len(o.Propagators) > 0 {
264+
return slices.Contains(o.Propagators, name)
265+
} else {
266+
return slices.Contains(strings.Split(os.Getenv("OTEL_PROPAGATORS"), ","), name)
267+
}
268+
}
269+
249270
func skipperDebugSpanExporter(ctx context.Context) (trace.SpanExporter, error) {
250271
return stdouttrace.New(stdouttrace.WithWriter(writerFunc(func(p []byte) (int, error) {
251272
log.Debugf("Span: %s", p)

otel/xxray/propagator.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package xxray
2+
3+
import (
4+
"context"
5+
"errors"
6+
"strings"
7+
8+
"go.opentelemetry.io/contrib/propagators/aws/xray"
9+
"go.opentelemetry.io/otel/propagation"
10+
"go.opentelemetry.io/otel/trace"
11+
)
12+
13+
// Propagator is an AWS X-Ray trace propagator that extends the standard [xray.Propagator].
14+
// Standard propagator requires both Root and Parent keys to be present in the X-Amzn-Trace-Id header
15+
// to successfully extract span context.
16+
// AWS [ALB request tracing] creates X-Amzn-Trace-Id header with only Root field - this propagator
17+
// can re-use it to obtain trace ID value.
18+
//
19+
// [ALB request tracing]: https://docs.aws.amazon.com/elasticloadbalancing/latest/application/load-balancer-request-tracing.html
20+
type Propagator struct {
21+
xray.Propagator
22+
idGenerator *xray.IDGenerator
23+
}
24+
25+
func NewPropagator() *Propagator {
26+
return &Propagator{idGenerator: xray.NewIDGenerator()}
27+
}
28+
29+
func (p *Propagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context {
30+
newCtx := p.Propagator.Extract(ctx, carrier)
31+
// If failed to extract span context, try to re-use trace id
32+
if newCtx == ctx {
33+
if header := carrier.Get(traceHeaderKey); header != "" {
34+
tsc, err := extract(header)
35+
if err == nil && tsc.TraceID().IsValid() {
36+
// Re-use only trace id
37+
return trace.ContextWithRemoteSpanContext(ctx, trace.NewSpanContext(trace.SpanContextConfig{
38+
TraceID: tsc.TraceID(),
39+
SpanID: p.idGenerator.NewSpanID(ctx, tsc.TraceID()),
40+
}))
41+
}
42+
}
43+
}
44+
return newCtx
45+
}
46+
47+
// The rest is copied from https://github.com/open-telemetry/opentelemetry-go-contrib/blob/80c9316336ebb4f4c67d8e1011a3add889213fb7/propagators/aws/xray/propagator.go
48+
const (
	// Name of the HTTP header carrying the X-Ray trace information.
	traceHeaderKey = "X-Amzn-Trace-Id"

	// Separators used inside the header value: fields are separated by
	// traceHeaderDelimiter, and each field is a key/value pair joined by
	// kvDelimiter.
	traceHeaderDelimiter = ";"
	kvDelimiter          = "="

	// Field keys recognised in the header value.
	traceIDKey    = "Root"
	sampleFlagKey = "Sampled"
	parentIDKey   = "Parent"

	// Values of the Sampled field.
	isSampled  = "1"
	notSampled = "0"

	// Textual layout of an X-Ray trace ID: a version, then two further parts,
	// each separated by traceIDDelimiter.
	traceIDVersion          = "1"
	traceIDDelimiter        = "-"
	traceIDLength           = 35
	traceIDDelimitterIndex1 = 1
	traceIDDelimitterIndex2 = 10
	traceIDFirstPartLength  = 8

	// Trace flag values and the length of the sampled flag.
	traceFlagNone     = 0x0
	traceFlagSampled  = 0x1 << 0
	sampledFlagLength = 1
)
68+
69+
var (
70+
empty = trace.SpanContext{}
71+
errInvalidTraceHeader = errors.New("invalid X-Amzn-Trace-Id header value, should contain 3 different part separated by ;")
72+
errMalformedTraceID = errors.New("cannot decode trace ID from header")
73+
errLengthTraceIDHeader = errors.New("incorrect length of X-Ray trace ID found, 35 character length expected")
74+
errInvalidTraceIDVersion = errors.New("invalid X-Ray trace ID header found, does not have valid trace ID version")
75+
errInvalidSpanIDLength = errors.New("invalid span ID length, must be 16")
76+
)
77+
78+
// extract extracts Span Context from context.
79+
func extract(headerVal string) (trace.SpanContext, error) {
80+
var (
81+
scc = trace.SpanContextConfig{}
82+
err error
83+
delimiterIndex int
84+
part string
85+
)
86+
pos := 0
87+
for pos < len(headerVal) {
88+
delimiterIndex = indexOf(headerVal, traceHeaderDelimiter, pos)
89+
if delimiterIndex >= 0 {
90+
part = headerVal[pos:delimiterIndex]
91+
pos = delimiterIndex + 1
92+
} else {
93+
// last part
94+
part = strings.TrimSpace(headerVal[pos:])
95+
pos = len(headerVal)
96+
}
97+
equalsIndex := strings.Index(part, kvDelimiter)
98+
if equalsIndex < 0 {
99+
return empty, errInvalidTraceHeader
100+
}
101+
value := part[equalsIndex+1:]
102+
switch {
103+
case strings.HasPrefix(part, traceIDKey):
104+
scc.TraceID, err = parseTraceID(value)
105+
if err != nil {
106+
return empty, err
107+
}
108+
case strings.HasPrefix(part, parentIDKey):
109+
// extract parentId
110+
scc.SpanID, err = trace.SpanIDFromHex(value)
111+
if err != nil {
112+
return empty, errInvalidSpanIDLength
113+
}
114+
case strings.HasPrefix(part, sampleFlagKey):
115+
// extract traceflag
116+
scc.TraceFlags = parseTraceFlag(value)
117+
}
118+
}
119+
return trace.NewSpanContext(scc), nil
120+
}
121+
122+
// indexOf returns the index in str of the first occurrence of substr at or
// after position pos, or -1 when substr does not occur there.
// The returned index is relative to the start of str, not to pos.
func indexOf(str, substr string, pos int) int {
	if rel := strings.Index(str[pos:], substr); rel >= 0 {
		return pos + rel
	}
	return -1
}
130+
131+
// parseTraceID returns trace ID if valid else return invalid trace ID.
132+
func parseTraceID(xrayTraceID string) (trace.TraceID, error) {
133+
if len(xrayTraceID) != traceIDLength {
134+
return empty.TraceID(), errLengthTraceIDHeader
135+
}
136+
if !strings.HasPrefix(xrayTraceID, traceIDVersion) {
137+
return empty.TraceID(), errInvalidTraceIDVersion
138+
}
139+
140+
if xrayTraceID[traceIDDelimitterIndex1:traceIDDelimitterIndex1+1] != traceIDDelimiter ||
141+
xrayTraceID[traceIDDelimitterIndex2:traceIDDelimitterIndex2+1] != traceIDDelimiter {
142+
return empty.TraceID(), errMalformedTraceID
143+
}
144+
145+
epochPart := xrayTraceID[traceIDDelimitterIndex1+1 : traceIDDelimitterIndex2]
146+
uniquePart := xrayTraceID[traceIDDelimitterIndex2+1 : traceIDLength]
147+
148+
result := epochPart + uniquePart
149+
return trace.TraceIDFromHex(result)
150+
}
151+
152+
// parseTraceFlag returns a parsed trace flag.
153+
func parseTraceFlag(xraySampledFlag string) trace.TraceFlags {
154+
// Use a direct comparison here (#7262).
155+
if xraySampledFlag == isSampled {
156+
return trace.FlagsSampled
157+
}
158+
return trace.FlagsSampled.WithSampled(false)
159+
}

0 commit comments

Comments
 (0)