package middleware

import (
	"context"
	"net/http"
	"sync"

	"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	"go.opentelemetry.io/otel/attribute"
	semconv "go.opentelemetry.io/otel/semconv/v1.18.0"
	"go.opentelemetry.io/otel/trace"
)

// Tracing adds tracing support for requests.
// Must be placed after prepareRequest middleware.
//
// For every request a server span is started via StartHTTPServerSpan, its
// trace ID is stored in the request's ReqInfo, and the downstream handler is
// served with the span's context and a traceResponseWriter wrapper. The
// wrapper ends the span when the status is written through its WriteHeader.
// A deferred fallback ends the span when the handler returns (or panics)
// without that happening, so the span is never left open.
func Tracing() Func {
	return func(h http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			appCtx, span := StartHTTPServerSpan(r, "REQUEST S3")
			reqInfo := GetReqInfo(r.Context())
			reqInfo.TraceID = span.SpanContext().TraceID().String()
			lw := &traceResponseWriter{ResponseWriter: w, ctx: appCtx, span: span}
			// The wrapper only ends the span from its WriteHeader. If the
			// handler writes nothing, panics, or commits the status some other
			// way, that never runs; end the span here instead. The shared
			// sync.Once makes this a no-op when WriteHeader already ended it.
			defer lw.Do(func() { span.End() })
			h.ServeHTTP(lw, r.WithContext(appCtx))
		})
	}
}

// traceResponseWriter wraps an http.ResponseWriter so the request span can
// record the response status, inject the trace context into the response
// headers, and be ended exactly once when the status is committed.
//
// The embedded sync.Once guards that one-time work; its Do method is promoted
// and may also be used by the enclosing middleware.
type traceResponseWriter struct {
	sync.Once
	http.ResponseWriter

	ctx  context.Context
	span trace.Span
}

// WriteHeader records code on the span, injects the trace context from ctx
// into the response headers, forwards code to the wrapped writer and ends the
// span. Only the first call has any effect; later calls (including their
// forward to the wrapped writer) are dropped.
func (lrw *traceResponseWriter) WriteHeader(code int) {
	lrw.Do(func() {
		lrw.span.SetAttributes(
			semconv.HTTPStatusCode(code),
		)

		// Headers must be set before the status is written; changes made to the
		// header map afterwards are not sent.
		carrier := &httpResponseCarrier{resp: lrw.ResponseWriter}
		tracing.Propagator.Inject(lrw.ctx, carrier)

		lrw.ResponseWriter.WriteHeader(code)
		lrw.span.End()
	})
}

// Write routes the implicit 200 OK that net/http sends for a Write without a
// prior WriteHeader through this wrapper's WriteHeader, then writes b.
// Without this override the embedded writer's Write would commit the status
// on its own, and the span would never record it or be ended here.
func (lrw *traceResponseWriter) Write(b []byte) (int, error) {
	lrw.WriteHeader(http.StatusOK)
	return lrw.ResponseWriter.Write(b)
}

// Flush forwards to the wrapped writer when it implements http.Flusher.
// Flushing before any status was written commits an implicit 200 OK, so that
// status is routed through WriteHeader first.
func (lrw *traceResponseWriter) Flush() {
	if f, ok := lrw.ResponseWriter.(http.Flusher); ok {
		lrw.WriteHeader(http.StatusOK)
		f.Flush()
	}
}

// httpResponseCarrier exposes the header map of an outgoing response through
// the Get/Set/Keys method set expected by trace context propagators, so trace
// context can be injected into the response.
type httpResponseCarrier struct {
	resp http.ResponseWriter
}

// Get returns the first value stored under key in the response headers, or ""
// when there is none.
func (c httpResponseCarrier) Get(key string) string {
	hdr := c.resp.Header()
	return hdr.Get(key)
}

// Set replaces any existing response header values for key with value.
func (c httpResponseCarrier) Set(key string, value string) {
	hdr := c.resp.Header()
	hdr.Set(key, value)
}

// Keys lists every header name currently present on the response, in no
// particular order.
func (c httpResponseCarrier) Keys() []string {
	hdr := c.resp.Header()
	names := make([]string, len(hdr))
	idx := 0
	for name := range hdr {
		names[idx] = name
		idx++
	}
	return names
}

// httpRequestCarrier exposes an incoming request's headers through the
// Get/Set/Keys method set expected by trace context propagators, so trace
// context sent by the client can be extracted.
type httpRequestCarrier struct {
	req *http.Request
}

// Get returns the first value stored under key in the request headers, or ""
// when there is none.
func (c *httpRequestCarrier) Get(key string) string {
	return c.req.Header.Get(key)
}

// Set replaces any existing request header values for key with value.
//
// It writes to the request's own headers: req.Response is only populated for
// requests created during client redirects and is nil for server-side
// requests, so it cannot be used as the destination.
func (c *httpRequestCarrier) Set(key string, value string) {
	if c.req.Header == nil {
		c.req.Header = make(http.Header)
	}
	c.req.Header.Set(key, value)
}

// Keys lists every header name present on the request, in no particular order.
func (c *httpRequestCarrier) Keys() []string {
	result := make([]string, 0, len(c.req.Header))
	for key := range c.req.Header {
		result = append(result, key)
	}

	return result
}

func extractHTTPTraceInfo(ctx context.Context, req *http.Request) context.Context {
	if req == nil {
		return ctx
	}
	carrier := &httpRequestCarrier{req: req}
	return tracing.Propagator.Extract(ctx, carrier)
}

// StartHTTPServerSpan starts the server-side span for an incoming HTTP request.
//
// Trace context carried in r's headers is extracted first, so the span may
// continue a trace begun by the client. The span is tagged with the client
// address, host, request ID, method, service name and request URI, and is
// started with kind SpanKindServer. Caller-supplied opts are placed before
// these defaults in the option list. The caller's opts slice is never
// modified.
func StartHTTPServerSpan(r *http.Request, operationName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
	ctx := extractHTTPTraceInfo(r.Context(), r)

	// Build a fresh slice: appending to opts directly could write into the
	// backing array of a slice the caller passed with `opts...`.
	spanOpts := make([]trace.SpanStartOption, 0, len(opts)+2)
	spanOpts = append(spanOpts, opts...)
	spanOpts = append(spanOpts,
		trace.WithAttributes(
			attribute.String("s3.client_address", r.RemoteAddr),
			// NOTE(review): "s3.path" is populated from r.Host, not the URL
			// path — confirm this is intended.
			attribute.String("s3.path", r.Host),
			attribute.String("s3.request_id", GetRequestID(r.Context())),
			semconv.HTTPMethod(r.Method),
			semconv.RPCService("frostfs-s3-gw"),
			attribute.String("s3.query", r.RequestURI),
		),
		trace.WithSpanKind(trace.SpanKindServer),
	)
	return tracing.StartSpanFromContext(ctx, operationName, spanOpts...)
}