package tracing

import (
	"context"
	"encoding/hex"
	"math/rand"
	"testing"

	"github.com/stretchr/testify/require"
	"go.opentelemetry.io/otel/trace"
)

// testCarrier is an in-memory, map-backed carrier. The tests use it to feed
// header values to the propagator and to observe what the propagator writes.
// Its Get/Set/Keys methods have the same shape as OpenTelemetry's
// propagation.TextMapCarrier. The zero value is ready to use: Set allocates
// the map on first write.
type testCarrier struct {
	Values map[string]string
}

// Get returns the value stored under key, or "" if the key is absent.
func (c *testCarrier) Get(key string) string {
	return c.Values[key]
}

// Set stores value under key, replacing any previous value.
func (c *testCarrier) Set(key string, value string) {
	if c.Values == nil {
		// Writing to a nil map panics. Allocate lazily so that a zero-value
		// carrier can be written to safely.
		c.Values = make(map[string]string)
	}
	c.Values[key] = value
}

// Keys returns every stored key in unspecified (map iteration) order.
func (c *testCarrier) Keys() []string {
	keys := make([]string, 0, len(c.Values))
	for k := range c.Values {
		keys = append(keys, k)
	}
	return keys
}

var (
	// p is the single propagator instance shared by every test in this file.
	p = &propagator{}
)

// TestPropagator_Inject checks the propagator's Inject behavior:
//   - For a span context whose trace ID and span ID are both valid, it writes
//     exactly three headers (trace ID, span ID, flags).
//   - When either ID is the all-zero (invalid) value, it writes nothing.
//   - Sampled contexts get flags "1"; contexts with no flags set get "0".
func TestPropagator_Inject(t *testing.T) {
	// randomHex returns n pseudo-random bytes, hex-encoded. It fills the
	// buffer byte by byte instead of calling math/rand.Read, which is
	// deprecated since Go 1.20.
	randomHex := func(n int) string {
		b := make([]byte, n)
		for i := range b {
			b[i] = byte(rand.Intn(256))
		}
		return hex.EncodeToString(b)
	}
	// zeroHex returns n zero bytes, hex-encoded. All-zero trace and span IDs
	// are invalid, so they drive the "nothing injected" cases.
	zeroHex := func(n int) string {
		return hex.EncodeToString(make([]byte, n))
	}

	tests := []struct {
		name         string
		traceIDHex   string
		spanIDHex    string
		flags        trace.TraceFlags
		wantInjected bool   // whether the carrier is expected to receive headers
		wantFlags    string // expected flags header when wantInjected is true
	}{
		{
			name:         "injects trace_id and span_id if valid",
			traceIDHex:   randomHex(16),
			spanIDHex:    randomHex(8),
			flags:        trace.FlagsSampled,
			wantInjected: true,
			wantFlags:    "1",
		},
		{
			name:       "doesn't injects if trace_id is invalid",
			traceIDHex: zeroHex(16),
			spanIDHex:  randomHex(8),
			flags:      trace.FlagsSampled,
		},
		{
			name:       "doesn't injects if span_id is invalid",
			traceIDHex: randomHex(16),
			spanIDHex:  zeroHex(8),
			flags:      trace.FlagsSampled,
		},
		{
			// flags is left at its zero value: no sampling bit set.
			name:         "injects flags if no flags specified",
			traceIDHex:   randomHex(16),
			spanIDHex:    randomHex(8),
			wantInjected: true,
			wantFlags:    "0",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Parse errors are ignored on purpose. An all-zero hex ID is
			// rejected as invalid but still yields the zero ID, and the
			// negative cases depend on exactly that.
			traceID, _ := trace.TraceIDFromHex(tt.traceIDHex)
			spanID, _ := trace.SpanIDFromHex(tt.spanIDHex)
			sc := trace.NewSpanContext(trace.SpanContextConfig{
				TraceID:    traceID,
				SpanID:     spanID,
				TraceFlags: tt.flags,
			})

			ctx := trace.ContextWithRemoteSpanContext(context.Background(), sc)
			c := &testCarrier{Values: make(map[string]string)}
			p.Inject(ctx, c)

			if !tt.wantInjected {
				require.Empty(t, c.Values, "some headers were saved")
				return
			}
			require.Len(t, c.Values, 3, "not all headers were saved")
			require.Equal(t, tt.traceIDHex, c.Values[traceIDHeader], "unexpected trace id")
			require.Equal(t, tt.spanIDHex, c.Values[spanIDHeader], "unexpected span id")
			require.Equal(t, tt.wantFlags, c.Values[flagsHeader], "unexpected flags")
		})
	}
}

func TestPropagator_Extract(t *testing.T) {
	t.Run("extracts if set", func(t *testing.T) {
		c := &testCarrier{
			Values: make(map[string]string),
		}

		traceIDBytes := make([]byte, 16)
		rand.Read(traceIDBytes)
		traceIDHex := hex.EncodeToString(traceIDBytes)
		c.Values[traceIDHeader] = traceIDHex

		spanIDBytes := make([]byte, 8)
		rand.Read(spanIDBytes)
		spanIDHex := hex.EncodeToString(spanIDBytes)
		c.Values[spanIDHeader] = spanIDHex

		c.Values[flagsHeader] = "1"

		ctx := p.Extract(context.Background(), c)

		sc := trace.SpanFromContext(ctx).SpanContext()
		require.True(t, sc.HasTraceID(), "trace_id was not set")
		require.Equal(t, traceIDHex, sc.TraceID().String(), "trace_id doesn't match")
		require.True(t, sc.HasSpanID(), "span_id was not set")
		require.Equal(t, spanIDHex, sc.SpanID().String(), "span_id doesn't match")
		require.True(t, sc.IsSampled(), "sampled was not set")
	})
	t.Run("not extracts if only trace_id defined", func(t *testing.T) {
		c := &testCarrier{
			Values: make(map[string]string),
		}

		traceIDBytes := make([]byte, 16)
		rand.Read(traceIDBytes)
		traceIDHex := hex.EncodeToString(traceIDBytes)
		c.Values[traceIDHeader] = traceIDHex
		c.Values[flagsHeader] = "1"

		ctx := p.Extract(context.Background(), c)

		sc := trace.SpanFromContext(ctx).SpanContext()
		require.False(t, sc.HasTraceID(), "trace_id was set")
		require.False(t, sc.HasSpanID(), "span_id was set")
		require.False(t, sc.IsSampled(), "sampled was set")
	})
	t.Run("not extracts if only span_id defined", func(t *testing.T) {
		c := &testCarrier{
			Values: make(map[string]string),
		}

		spanIDBytes := make([]byte, 8)
		rand.Read(spanIDBytes)
		spanIDHex := hex.EncodeToString(spanIDBytes)
		c.Values[spanIDHeader] = spanIDHex
		c.Values[flagsHeader] = "1"

		ctx := p.Extract(context.Background(), c)

		sc := trace.SpanFromContext(ctx).SpanContext()
		require.False(t, sc.HasTraceID(), "trace_id was set")
		require.False(t, sc.HasSpanID(), "span_id was set")
		require.False(t, sc.IsSampled(), "sampled was set")
	})
	t.Run("not extracts if trace_id is in invalid", func(t *testing.T) {
		c := &testCarrier{
			Values: make(map[string]string),
		}

		c.Values[traceIDHeader] = "loren ipsum"

		spanIDBytes := make([]byte, 8)
		rand.Read(spanIDBytes)
		spanIDHex := hex.EncodeToString(spanIDBytes)
		c.Values[spanIDHeader] = spanIDHex
		c.Values[flagsHeader] = "1"

		ctx := p.Extract(context.Background(), c)

		sc := trace.SpanFromContext(ctx).SpanContext()
		require.False(t, sc.HasTraceID(), "trace_id was set")
		require.False(t, sc.HasSpanID(), "span_id was set")
		require.False(t, sc.IsSampled(), "sampled was set")
	})
	t.Run("not extracts if span_id is invalid", func(t *testing.T) {
		c := &testCarrier{
			Values: make(map[string]string),
		}

		c.Values[spanIDHeader] = "loren ipsum"

		traceIDBytes := make([]byte, 16)
		rand.Read(traceIDBytes)
		traceIDHex := hex.EncodeToString(traceIDBytes)
		c.Values[traceIDHeader] = traceIDHex
		c.Values[flagsHeader] = "1"

		ctx := p.Extract(context.Background(), c)

		sc := trace.SpanFromContext(ctx).SpanContext()
		require.False(t, sc.HasTraceID(), "trace_id was set")
		require.False(t, sc.HasSpanID(), "span_id was set")
		require.False(t, sc.IsSampled(), "sampled was set")
	})
	t.Run("not extracts if flags is invalid", func(t *testing.T) {
		c := &testCarrier{
			Values: make(map[string]string),
		}

		traceIDBytes := make([]byte, 16)
		rand.Read(traceIDBytes)
		traceIDHex := hex.EncodeToString(traceIDBytes)
		c.Values[traceIDHeader] = traceIDHex

		spanIDBytes := make([]byte, 8)
		rand.Read(spanIDBytes)
		spanIDHex := hex.EncodeToString(spanIDBytes)
		c.Values[spanIDHeader] = spanIDHex

		c.Values[flagsHeader] = "loren ipsum"

		ctx := p.Extract(context.Background(), c)

		sc := trace.SpanFromContext(ctx).SpanContext()
		require.False(t, sc.HasTraceID(), "trace_id was set")
		require.False(t, sc.HasSpanID(), "span_id was set")
		require.False(t, sc.IsSampled(), "sampled was set")
	})
}