package tracing

import (
	"context"
	"fmt"
	"strconv"

	"go.opentelemetry.io/otel/propagation"
	"go.opentelemetry.io/otel/trace"
)

// Carrier keys read and written by the propagator.
const (
	// traceIDHeader is the carrier key for the hex-encoded trace ID.
	traceIDHeader = "x-frostfs-trace-id"
	// spanIDHeader is the carrier key for the hex-encoded span ID.
	spanIDHeader  = "x-frostfs-span-id"
	// flagsHeader is the carrier key for the hex-encoded trace flags.
	flagsHeader   = "x-frostfs-trace-flags"
)

// Bits of the value stored under the trace flags header.
const (
	// flagsSampled (bit 0, value 1) is set by Inject when the span context is
	// sampled and checked by Extract to restore the sampled trace flag.
	flagsSampled = 1 << iota
)

// propagator serializes a SpanContext to and from carrier headers:
//   - x-frostfs-trace-id: TraceID, 16 bytes, written as a 32-character hex string;
//   - x-frostfs-span-id: SpanID, 8 bytes, written as a 16-character hex string;
//   - x-frostfs-trace-flags: trace flags in hex (only the sampled flag for now).
type propagator struct{}

// Propagator is the propagation.TextMapPropagator instance used to inject trace
// info into a carrier and to extract it from a carrier into a context.
var Propagator propagation.TextMapPropagator = &propagator{}

// Inject writes the trace ID, span ID and trace flags of the span stored in ctx
// to carrier. Nothing is written if the trace ID or the span ID is invalid.
func (p *propagator) Inject(ctx context.Context, carrier propagation.TextMapCarrier) {
	spanCtx := trace.SpanFromContext(ctx).SpanContext()
	if !(spanCtx.TraceID().IsValid() && spanCtx.SpanID().IsValid()) {
		return
	}

	// Only the sampled bit is propagated; every other flag stays zero.
	traceFlags := 0
	if spanCtx.IsSampled() {
		traceFlags |= flagsSampled
	}

	carrier.Set(traceIDHeader, spanCtx.TraceID().String())
	carrier.Set(spanIDHeader, spanCtx.SpanID().String())
	carrier.Set(flagsHeader, fmt.Sprintf("%x", traceFlags))
}

// Extract reads tracing info from carrier and returns a context that carries it
// as a remote span context.
//
// ctx is returned unchanged when:
//   - the trace ID or span ID header is missing (one of them or both);
//   - the trace ID, span ID or flags header is present but malformed.
//
// Returning ctx as is when nothing can be extracted keeps any span already
// stored in ctx instead of replacing it with an invalid remote span context.
func (p *propagator) Extract(ctx context.Context, carrier propagation.TextMapCarrier) context.Context {
	traceIDStr := carrier.Get(traceIDHeader)
	spanIDStr := carrier.Get(spanIDHeader)
	if traceIDStr == "" || spanIDStr == "" {
		// A span context needs both IDs: a single ID is an error, and no IDs
		// at all means the carrier has no tracing info to extract.
		return ctx
	}

	var (
		spanConfig trace.SpanContextConfig
		err        error
	)
	if spanConfig.TraceID, err = trace.TraceIDFromHex(traceIDStr); err != nil {
		return ctx
	}
	if spanConfig.SpanID, err = trace.SpanIDFromHex(spanIDStr); err != nil {
		return ctx
	}

	// The flags header is optional; when absent the span is treated as not sampled.
	if flagsStr := carrier.Get(flagsHeader); flagsStr != "" {
		v, err := strconv.ParseInt(flagsStr, 16, 32)
		if err != nil {
			return ctx
		}
		if v&flagsSampled != 0 {
			spanConfig.TraceFlags = trace.FlagsSampled
		}
	}

	return trace.ContextWithRemoteSpanContext(ctx, trace.NewSpanContext(spanConfig))
}

// Fields lists the carrier keys that Inject may set.
func (p *propagator) Fields() []string {
	keys := make([]string, 0, 3)
	keys = append(keys, traceIDHeader, spanIDHeader, flagsHeader)
	return keys
}