package tracing

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"testing"

	"github.com/stretchr/testify/require"
	"go.opentelemetry.io/otel/trace"
)

// testCarrier is an in-memory carrier backed by a plain map. Tests use it to
// observe exactly which headers the propagator writes and to feed it headers
// to read.
type testCarrier struct {
	Values map[string]string
}

// Get returns the value stored under key, or "" when the key is absent.
func (c *testCarrier) Get(key string) string {
	return c.Values[key]
}

// Set stores value under key, overwriting any previous value.
func (c *testCarrier) Set(key string, value string) {
	c.Values[key] = value
}

// Keys returns every stored key in unspecified (map iteration) order.
func (c *testCarrier) Keys() []string {
	res := make([]string, 0, len(c.Values))
	for k := range c.Values {
		res = append(res, k)
	}
	return res
}

// p is the propagator under test; a single instance is shared by every subtest.
var p = &propagator{}

// newTestCarrier returns an empty testCarrier whose map is ready for writes.
func newTestCarrier() *testCarrier {
	return &testCarrier{Values: make(map[string]string)}
}

// randomHexID returns size bytes read from crypto/rand, hex-encoded. Trace IDs
// use 16 bytes and span IDs 8 bytes. The test fails if the random source
// returns an error instead of silently using a partially filled buffer.
func randomHexID(t *testing.T, size int) string {
	t.Helper()
	b := make([]byte, size)
	_, err := rand.Read(b)
	require.NoError(t, err, "reading random bytes")
	return hex.EncodeToString(b)
}

// zeroHexID returns size zero bytes, hex-encoded. The subtests below use an
// all-zero ID as the "invalid ID" input.
func zeroHexID(size int) string {
	return hex.EncodeToString(make([]byte, size))
}

// remoteSpanContext returns a background context carrying a remote span
// context built from the given hex IDs and trace flags.
//
// Parse errors are ignored on purpose. The invalid-ID subtests pass all-zero
// IDs and want the resulting invalid span context to reach Inject regardless
// of what the parser reports. The valid-ID subtests still catch a parse
// failure, because they assert the exact header values that were injected.
func remoteSpanContext(traceIDHex, spanIDHex string, flags trace.TraceFlags) context.Context {
	cfg := trace.SpanContextConfig{TraceFlags: flags}
	cfg.TraceID, _ = trace.TraceIDFromHex(traceIDHex)
	cfg.SpanID, _ = trace.SpanIDFromHex(spanIDHex)
	return trace.ContextWithRemoteSpanContext(context.Background(), trace.NewSpanContext(cfg))
}

// requireNoSpanContext fails the test if ctx carries any part of a span
// context: a trace ID, a span ID, or the sampled flag.
func requireNoSpanContext(t *testing.T, ctx context.Context) {
	t.Helper()
	sc := trace.SpanFromContext(ctx).SpanContext()
	require.False(t, sc.HasTraceID(), "trace_id was set")
	require.False(t, sc.HasSpanID(), "span_id was set")
	require.False(t, sc.IsSampled(), "sampled was set")
}

func TestPropagator_Inject(t *testing.T) {
	t.Run("injects trace_id and span_id if valid", func(t *testing.T) {
		traceIDHex := randomHexID(t, 16)
		spanIDHex := randomHexID(t, 8)
		ctx := remoteSpanContext(traceIDHex, spanIDHex, trace.FlagsSampled)

		c := newTestCarrier()
		p.Inject(ctx, c)

		require.Len(t, c.Values, 3, "not all headers were saved")
		require.Equal(t, traceIDHex, c.Values[traceIDHeader], "unexpected trace id")
		require.Equal(t, spanIDHex, c.Values[spanIDHeader], "unexpected span id")
		require.Equal(t, "1", c.Values[flagsHeader], "unexpected flags")
	})

	t.Run("doesn't injects if trace_id is invalid", func(t *testing.T) {
		ctx := remoteSpanContext(zeroHexID(16), randomHexID(t, 8), trace.FlagsSampled)

		c := newTestCarrier()
		p.Inject(ctx, c)

		require.Empty(t, c.Values, "some headers were saved")
	})

	t.Run("doesn't injects if span_id is invalid", func(t *testing.T) {
		ctx := remoteSpanContext(randomHexID(t, 16), zeroHexID(8), trace.FlagsSampled)

		c := newTestCarrier()
		p.Inject(ctx, c)

		require.Empty(t, c.Values, "some headers were saved")
	})

	t.Run("injects flags if no flags specified", func(t *testing.T) {
		traceIDHex := randomHexID(t, 16)
		spanIDHex := randomHexID(t, 8)
		// Zero flags: the sampled bit is left unset.
		ctx := remoteSpanContext(traceIDHex, spanIDHex, 0)

		c := newTestCarrier()
		p.Inject(ctx, c)

		require.Len(t, c.Values, 3, "not all headers were saved")
		require.Equal(t, traceIDHex, c.Values[traceIDHeader], "unexpected trace id")
		require.Equal(t, spanIDHex, c.Values[spanIDHeader], "unexpected span id")
		require.Equal(t, "0", c.Values[flagsHeader], "unexpected flags")
	})
}

func TestPropagator_Extract(t *testing.T) {
	t.Run("extracts if set", func(t *testing.T) {
		traceIDHex := randomHexID(t, 16)
		spanIDHex := randomHexID(t, 8)
		c := &testCarrier{Values: map[string]string{
			traceIDHeader: traceIDHex,
			spanIDHeader:  spanIDHex,
			flagsHeader:   "1",
		}}

		ctx := p.Extract(context.Background(), c)
		sc := trace.SpanFromContext(ctx).SpanContext()

		require.True(t, sc.HasTraceID(), "trace_id was not set")
		require.Equal(t, traceIDHex, sc.TraceID().String(), "trace_id doesn't match")
		require.True(t, sc.HasSpanID(), "span_id was not set")
		require.Equal(t, spanIDHex, sc.SpanID().String(), "span_id doesn't match")
		require.True(t, sc.IsSampled(), "sampled was not set")
	})

	t.Run("not extracts if only trace_id defined", func(t *testing.T) {
		c := &testCarrier{Values: map[string]string{
			traceIDHeader: randomHexID(t, 16),
			flagsHeader:   "1",
		}}

		requireNoSpanContext(t, p.Extract(context.Background(), c))
	})

	t.Run("not extracts if only span_id defined", func(t *testing.T) {
		c := &testCarrier{Values: map[string]string{
			spanIDHeader: randomHexID(t, 8),
			flagsHeader:  "1",
		}}

		requireNoSpanContext(t, p.Extract(context.Background(), c))
	})

	t.Run("not extracts if trace_id is in invalid", func(t *testing.T) {
		c := &testCarrier{Values: map[string]string{
			traceIDHeader: "loren ipsum",
			spanIDHeader:  randomHexID(t, 8),
			flagsHeader:   "1",
		}}

		requireNoSpanContext(t, p.Extract(context.Background(), c))
	})

	t.Run("not extracts if span_id is invalid", func(t *testing.T) {
		c := &testCarrier{Values: map[string]string{
			traceIDHeader: randomHexID(t, 16),
			spanIDHeader:  "loren ipsum",
			flagsHeader:   "1",
		}}

		requireNoSpanContext(t, p.Extract(context.Background(), c))
	})

	t.Run("not extracts if flags is invalid", func(t *testing.T) {
		c := &testCarrier{Values: map[string]string{
			traceIDHeader: randomHexID(t, 16),
			spanIDHeader:  randomHexID(t, 8),
			flagsHeader:   "loren ipsum",
		}}

		requireNoSpanContext(t, p.Extract(context.Background(), c))
	})
}