diff --git a/core/dnsserver/server_https.go b/core/dnsserver/server_https.go index ba6097215..5c884e56b 100644 --- a/core/dnsserver/server_https.go +++ b/core/dnsserver/server_https.go @@ -27,6 +27,9 @@ type ServerHTTPS struct { validRequest func(*http.Request) bool } +// HTTPRequestKey is the context key for the current processed HTTP request (if current processed request was done over DOH) +type HTTPRequestKey struct{} + // NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it. func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) { s, err := NewServer(addr, group) @@ -153,6 +156,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) { // We should expect a packet to be returned that we can send to the client. ctx := context.WithValue(context.Background(), Key{}, s.Server) ctx = context.WithValue(ctx, LoopKey{}, 0) + ctx = context.WithValue(ctx, HTTPRequestKey{}, r) s.ServeDNS(ctx, dw, msg) // See section 4.2.1 of RFC 8484. diff --git a/plugin/trace/trace.go b/plugin/trace/trace.go index 87cb65e68..6bfd94dae 100644 --- a/plugin/trace/trace.go +++ b/plugin/trace/trace.go @@ -4,9 +4,11 @@ package trace import ( "context" "fmt" + "net/http" "sync" "sync/atomic" + "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/metadata" "github.com/coredns/coredns/plugin/pkg/dnstest" @@ -140,8 +142,15 @@ func (t *trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r) } + var spanCtx ot.SpanContext + if val := ctx.Value(dnsserver.HTTPRequestKey{}); val != nil { + if httpReq, ok := val.(*http.Request); ok { + spanCtx, _ = t.Tracer().Extract(ot.HTTPHeaders, ot.HTTPHeadersCarrier(httpReq.Header)) + } + } + req := request.Request{W: w, Req: r} - span = t.Tracer().StartSpan(defaultTopLevelSpanName) + span = t.Tracer().StartSpan(defaultTopLevelSpanName, otext.RPCServerOption(spanCtx)) defer span.Finish() switch spanCtx := span.Context().(type) { diff --git a/plugin/trace/trace_test.go b/plugin/trace/trace_test.go index dae546f8d..940eb6b02 100644 --- a/plugin/trace/trace_test.go +++ b/plugin/trace/trace_test.go @@ -3,9 +3,11 @@ package trace import ( "context" "errors" + "net/http/httptest" "testing" "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" "github.com/coredns/coredns/plugin/pkg/dnstest" "github.com/coredns/coredns/plugin/pkg/rcode" @@ -13,6 +15,7 @@ import ( "github.com/coredns/coredns/request" "github.com/miekg/dns" + "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/mocktracer" ) @@ -131,3 +134,40 @@ func TestTrace(t *testing.T) { }) } } + +func TestTrace_DOH_TraceHeaderExtraction(t *testing.T) { + w := dnstest.NewRecorder(&test.ResponseWriter{}) + m := mocktracer.New() + tr := &trace{ + Next: test.HandlerFunc(func(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + if plugin.ClientWrite(dns.RcodeSuccess) { + m := new(dns.Msg) + m.SetRcode(r, dns.RcodeSuccess) + w.WriteMsg(m) + } + return dns.RcodeSuccess, nil + }), + every: 1, + tracer: m, + } + q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA) + + req := httptest.NewRequest("POST", "/dns-query", nil) + + outsideSpan := m.StartSpan("test-header-span") + outsideSpan.Tracer().Inject(outsideSpan.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header)) + defer outsideSpan.Finish() + + ctx := context.TODO() + ctx = context.WithValue(ctx, dnsserver.HTTPRequestKey{}, req) + + tr.ServeDNS(ctx, w, q) + + fs := m.FinishedSpans() + rootCoreDNSspan := fs[1] + rootCoreDNSTraceID := rootCoreDNSspan.Context().(mocktracer.MockSpanContext).TraceID + outsideSpanTraceID := outsideSpan.Context().(mocktracer.MockSpanContext).TraceID + if rootCoreDNSTraceID != outsideSpanTraceID { + t.Errorf("Unexpected traceID: rootSpan.TraceID: want %v, got %v", rootCoreDNSTraceID, outsideSpanTraceID) + } +}