plugin/trace: read trace context info from headers for DOH (#5439)
Signed-off-by: Ondřej Benkovský <ondrej.benkovsky@jamf.com>
This commit is contained in:
parent
1290427645
commit
af4d84d915
3 changed files with 54 additions and 1 deletions
|
@ -27,6 +27,9 @@ type ServerHTTPS struct {
|
||||||
validRequest func(*http.Request) bool
|
validRequest func(*http.Request) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HTTPRequestKey is the context key for the current processed HTTP request (if current processed request was done over DOH)
|
||||||
|
type HTTPRequestKey struct{}
|
||||||
|
|
||||||
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
|
// NewServerHTTPS returns a new CoreDNS HTTPS server and compiles all plugins in to it.
|
||||||
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
func NewServerHTTPS(addr string, group []*Config) (*ServerHTTPS, error) {
|
||||||
s, err := NewServer(addr, group)
|
s, err := NewServer(addr, group)
|
||||||
|
@ -153,6 +156,7 @@ func (s *ServerHTTPS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// We should expect a packet to be returned that we can send to the client.
|
// We should expect a packet to be returned that we can send to the client.
|
||||||
ctx := context.WithValue(context.Background(), Key{}, s.Server)
|
ctx := context.WithValue(context.Background(), Key{}, s.Server)
|
||||||
ctx = context.WithValue(ctx, LoopKey{}, 0)
|
ctx = context.WithValue(ctx, LoopKey{}, 0)
|
||||||
|
ctx = context.WithValue(ctx, HTTPRequestKey{}, r)
|
||||||
s.ServeDNS(ctx, dw, msg)
|
s.ServeDNS(ctx, dw, msg)
|
||||||
|
|
||||||
// See section 4.2.1 of RFC 8484.
|
// See section 4.2.1 of RFC 8484.
|
||||||
|
|
|
@ -4,9 +4,11 @@ package trace
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/coredns/coredns/core/dnsserver"
|
||||||
"github.com/coredns/coredns/plugin"
|
"github.com/coredns/coredns/plugin"
|
||||||
"github.com/coredns/coredns/plugin/metadata"
|
"github.com/coredns/coredns/plugin/metadata"
|
||||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||||
|
@ -140,8 +142,15 @@ func (t *trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
|
||||||
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
|
return plugin.NextOrFailure(t.Name(), t.Next, ctx, w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var spanCtx ot.SpanContext
|
||||||
|
if val := ctx.Value(dnsserver.HTTPRequestKey{}); val != nil {
|
||||||
|
if httpReq, ok := val.(*http.Request); ok {
|
||||||
|
spanCtx, _ = t.Tracer().Extract(ot.HTTPHeaders, ot.HTTPHeadersCarrier(httpReq.Header))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
req := request.Request{W: w, Req: r}
|
req := request.Request{W: w, Req: r}
|
||||||
span = t.Tracer().StartSpan(defaultTopLevelSpanName)
|
span = t.Tracer().StartSpan(defaultTopLevelSpanName, otext.RPCServerOption(spanCtx))
|
||||||
defer span.Finish()
|
defer span.Finish()
|
||||||
|
|
||||||
switch spanCtx := span.Context().(type) {
|
switch spanCtx := span.Context().(type) {
|
||||||
|
|
|
@ -3,9 +3,11 @@ package trace
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/coredns/caddy"
|
"github.com/coredns/caddy"
|
||||||
|
"github.com/coredns/coredns/core/dnsserver"
|
||||||
"github.com/coredns/coredns/plugin"
|
"github.com/coredns/coredns/plugin"
|
||||||
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
"github.com/coredns/coredns/plugin/pkg/dnstest"
|
||||||
"github.com/coredns/coredns/plugin/pkg/rcode"
|
"github.com/coredns/coredns/plugin/pkg/rcode"
|
||||||
|
@ -13,6 +15,7 @@ import (
|
||||||
"github.com/coredns/coredns/request"
|
"github.com/coredns/coredns/request"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"github.com/opentracing/opentracing-go"
|
||||||
"github.com/opentracing/opentracing-go/mocktracer"
|
"github.com/opentracing/opentracing-go/mocktracer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -131,3 +134,40 @@ func TestTrace(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTrace_DOH_TraceHeaderExtraction(t *testing.T) {
|
||||||
|
w := dnstest.NewRecorder(&test.ResponseWriter{})
|
||||||
|
m := mocktracer.New()
|
||||||
|
tr := &trace{
|
||||||
|
Next: test.HandlerFunc(func(_ context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
if plugin.ClientWrite(dns.RcodeSuccess) {
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetRcode(r, dns.RcodeSuccess)
|
||||||
|
w.WriteMsg(m)
|
||||||
|
}
|
||||||
|
return dns.RcodeSuccess, nil
|
||||||
|
}),
|
||||||
|
every: 1,
|
||||||
|
tracer: m,
|
||||||
|
}
|
||||||
|
q := new(dns.Msg).SetQuestion("example.net.", dns.TypeA)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/dns-query", nil)
|
||||||
|
|
||||||
|
outsideSpan := m.StartSpan("test-header-span")
|
||||||
|
outsideSpan.Tracer().Inject(outsideSpan.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header))
|
||||||
|
defer outsideSpan.Finish()
|
||||||
|
|
||||||
|
ctx := context.TODO()
|
||||||
|
ctx = context.WithValue(ctx, dnsserver.HTTPRequestKey{}, req)
|
||||||
|
|
||||||
|
tr.ServeDNS(ctx, w, q)
|
||||||
|
|
||||||
|
fs := m.FinishedSpans()
|
||||||
|
rootCoreDNSspan := fs[1]
|
||||||
|
rootCoreDNSTraceID := rootCoreDNSspan.Context().(mocktracer.MockSpanContext).TraceID
|
||||||
|
outsideSpanTraceID := outsideSpan.Context().(mocktracer.MockSpanContext).TraceID
|
||||||
|
if rootCoreDNSTraceID != outsideSpanTraceID {
|
||||||
|
t.Errorf("Unexpected traceID: rootSpan.TraceID: want %v, got %v", rootCoreDNSTraceID, outsideSpanTraceID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue