From 9ea8cde36e24cb68d6cb2c5db7b6b08e204db7a8 Mon Sep 17 00:00:00 2001 From: John Belamaric Date: Wed, 1 Mar 2017 10:41:54 -0500 Subject: [PATCH] Grpc tracing (#544) * checkpoint * Pass context through ServeDNS, enable gRPC tracing * Fix types and make tracer available to proxy. go fmt * Fix imports * Use the DoNotStartTrace option * Change to SpanFilter from DoNotStartTrace * Use new name (IncludeSpan) * Final names * Add tests; fix possible client/conn leaks in grpc * go fmt --- core/dnsserver/server.go | 5 +++- middleware/proxy/grpc.go | 28 +++++++++++++++--- middleware/proxy/grpc_test.go | 54 ++++++++++++++++++++++++++++++++++ middleware/proxy/proxy.go | 4 +++ middleware/proxy/setup.go | 3 +- middleware/trace/setup.go | 4 +-- middleware/trace/trace.go | 25 +++++++++++----- middleware/trace/trace_test.go | 33 +++++++++++++++++++++ 8 files changed, 140 insertions(+), 16 deletions(-) create mode 100644 middleware/proxy/grpc_test.go create mode 100644 middleware/trace/trace_test.go diff --git a/core/dnsserver/server.go b/core/dnsserver/server.go index bb8a80d79..4906571e6 100644 --- a/core/dnsserver/server.go +++ b/core/dnsserver/server.go @@ -155,6 +155,10 @@ func (s *Server) Address() string { return s.Addr } // defined in the request so that the correct zone // (configuration and middleware stack) will handle the request. func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + s.ServeDNSWithContext(context.Background(), w, r) +} + +func (s *Server) ServeDNSWithContext(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) { defer func() { // In case the user doesn't enable error middleware, we still // need to make sure that we stay alive up here @@ -171,7 +175,6 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { q := r.Question[0].Name b := make([]byte, len(q)) off, end := 0, false - ctx := context.Background() var dshandler *Config diff --git a/middleware/proxy/grpc.go b/middleware/proxy/grpc.go index aaf908d2a..c480d3cf2 100644 --- a/middleware/proxy/grpc.go +++ b/middleware/proxy/grpc.go @@ -5,16 +5,22 @@ import ( "crypto/tls" "log" + "github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc" + "github.com/coredns/coredns/middleware/proxy/pb" + "github.com/coredns/coredns/middleware/trace" "github.com/coredns/coredns/request" "github.com/miekg/dns" + + opentracing "github.com/opentracing/opentracing-go" + "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) type grpcClient struct { - dialOpt grpc.DialOption + dialOpts []grpc.DialOption clients map[string]pb.DnsServiceClient conns []*grpc.ClientConn upstream *staticUpstream @@ -24,9 +30,9 @@ func newGrpcClient(tls *tls.Config, u *staticUpstream) *grpcClient { g := &grpcClient{upstream: u} if tls == nil { - g.dialOpt = grpc.WithInsecure() + g.dialOpts = append(g.dialOpts, grpc.WithInsecure()) } else { - g.dialOpt = grpc.WithTransportCredentials(credentials.NewTLS(tls)) + g.dialOpts = append(g.dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tls))) } g.clients = map[string]pb.DnsServiceClient{} @@ -54,18 +60,32 @@ func (g *grpcClient) Exchange(ctx context.Context, addr string, state request.Re func (g *grpcClient) Protocol() string { return "grpc" } func (g *grpcClient) OnShutdown(p *Proxy) error { + g.clients = map[string]pb.DnsServiceClient{} for i, conn := range g.conns { err := conn.Close() if err != nil { log.Printf("[WARNING] Error closing connection %d: %s\n", i, err) } } + g.conns = []*grpc.ClientConn{} return nil } func (g *grpcClient) OnStartup(p *Proxy) error { + dialOpts := g.dialOpts + if p.Trace != nil { + if t, ok := p.Trace.(trace.Trace); ok { + onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp interface{}) bool { + return parentSpanCtx != nil + } + intercept := otgrpc.OpenTracingClientInterceptor(t.Tracer(), otgrpc.IncludingSpans(onlyIfParent)) + dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(intercept)) + } else { + log.Printf("[WARNING] Wrong type for trace middleware reference: %s", p.Trace) + } + } for _, host := range g.upstream.Hosts { - conn, err := grpc.Dial(host.Name, g.dialOpt) + conn, err := grpc.Dial(host.Name, dialOpts...) if err != nil { log.Printf("[WARNING] Skipping gRPC host '%s' due to Dial error: %s\n", host.Name, err) } else { diff --git a/middleware/proxy/grpc_test.go b/middleware/proxy/grpc_test.go new file mode 100644 index 000000000..0eade58a9 --- /dev/null +++ b/middleware/proxy/grpc_test.go @@ -0,0 +1,54 @@ +package proxy + +import ( + "testing" + "time" +) + +func pool() []*UpstreamHost { + return []*UpstreamHost{ + { + Name: "localhost:10053", + }, + { + Name: "localhost:10054", + }, + } +} + +func TestStartupShutdown(t *testing.T) { + upstream := &staticUpstream{ + from: ".", + Hosts: pool(), + Policy: &Random{}, + Spray: nil, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + g := newGrpcClient(nil, upstream) + upstream.ex = g + + p := &Proxy{Trace: nil} + p.Upstreams = &[]Upstream{upstream} + + err := g.OnStartup(p) + if err != nil { + t.Errorf("Error starting grpc client exchanger: %s", err) + return + } + if len(g.clients) != len(pool()) { + t.Errorf("Expected %d grpc clients but found %d", len(pool()), len(g.clients)) + } + + err = g.OnShutdown(p) + if err != nil { + t.Errorf("Error stopping grpc client exchanger: %s", err) + return + } + if len(g.clients) != 0 { + t.Errorf("Shutdown didn't remove clients, found %d", len(g.clients)) + } + if len(g.conns) != 0 { + t.Errorf("Shutdown didn't remove conns, found %d", len(g.conns)) + } +} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 090c070cb..9457fb2a1 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -28,6 +28,10 @@ type Proxy struct { // midway. Upstreams *[]Upstream + + // Trace is the Trace middleware, if it is installed + // This is used by the grpc exchanger to trace through the grpc calls + Trace middleware.Handler } // Upstream manages a pool of proxy upstream hosts. Select should return a diff --git a/middleware/proxy/setup.go b/middleware/proxy/setup.go index 3e4f262b7..36401188f 100644 --- a/middleware/proxy/setup.go +++ b/middleware/proxy/setup.go @@ -20,7 +20,8 @@ func setup(c *caddy.Controller) error { return middleware.Error("proxy", err) } - P := &Proxy{} + t := dnsserver.GetMiddleware(c, "trace") + P := &Proxy{Trace: t} dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler { P.Next = next P.Upstreams = &upstreams diff --git a/middleware/trace/setup.go b/middleware/trace/setup.go index 4538e5f1d..a6eb8c340 100644 --- a/middleware/trace/setup.go +++ b/middleware/trace/setup.go @@ -34,9 +34,9 @@ func setup(c *caddy.Controller) error { return nil } -func traceParse(c *caddy.Controller) (*Trace, error) { +func traceParse(c *caddy.Controller) (*trace, error) { var ( - tr = &Trace{Endpoint: defEP, EndpointType: defEpType, every: 1, serviceName: defServiceName} + tr = &trace{Endpoint: defEP, EndpointType: defEpType, every: 1, serviceName: defServiceName} err error ) diff --git a/middleware/trace/trace.go b/middleware/trace/trace.go index 1b09e2914..3413fa681 100644 --- a/middleware/trace/trace.go +++ b/middleware/trace/trace.go @@ -15,12 +15,17 @@ import ( ) // Trace holds the tracer and endpoint info -type Trace struct { +type Trace interface { + middleware.Handler + Tracer() ot.Tracer +} + +type trace struct { Next middleware.Handler ServiceEndpoint string Endpoint string EndpointType string - Tracer ot.Tracer + tracer ot.Tracer serviceName string clientServer bool every uint64 @@ -28,8 +33,12 @@ type Trace struct { Once sync.Once } +func (t *trace) Tracer() ot.Tracer { + return t.tracer +} + // OnStartup sets up the tracer -func (t *Trace) OnStartup() error { +func (t *trace) OnStartup() error { var err error t.Once.Do(func() { switch t.EndpointType { @@ -42,7 +51,7 @@ func (t *Trace) OnStartup() error { return err } -func (t *Trace) setupZipkin() error { +func (t *trace) setupZipkin() error { collector, err := zipkin.NewHTTPCollector(t.Endpoint) if err != nil { @@ -50,7 +59,7 @@ func (t *Trace) setupZipkin() error { } recorder := zipkin.NewRecorder(collector, false, t.ServiceEndpoint, t.serviceName) - t.Tracer, err = zipkin.NewTracer(recorder, zipkin.ClientServerSameSpan(t.clientServer)) + t.tracer, err = zipkin.NewTracer(recorder, zipkin.ClientServerSameSpan(t.clientServer)) if err != nil { return err } @@ -58,12 +67,12 @@ func (t *Trace) setupZipkin() error { } // Name implements the Handler interface. -func (t *Trace) Name() string { +func (t *trace) Name() string { return "trace" } // ServeDNS implements the middleware.Handle interface. -func (t *Trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { +func (t *trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { trace := false if t.every > 0 { queryNr := atomic.AddUint64(&t.count, 1) @@ -73,7 +82,7 @@ func (t *Trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) } } if span := ot.SpanFromContext(ctx); span == nil && trace { - span := t.Tracer.StartSpan("servedns") + span := t.Tracer().StartSpan("servedns") defer span.Finish() ctx = ot.ContextWithSpan(ctx, span) } diff --git a/middleware/trace/trace_test.go b/middleware/trace/trace_test.go new file mode 100644 index 000000000..37dec2065 --- /dev/null +++ b/middleware/trace/trace_test.go @@ -0,0 +1,33 @@ +package trace + +import ( + "testing" + + "github.com/mholt/caddy" +) + +// CreateTestTrace creates a trace middleware to be used in tests +func CreateTestTrace(config string) (*caddy.Controller, *trace, error) { + c := caddy.NewTestController("dns", config) + m, err := traceParse(c) + return c, m, err +} + +func TestTrace(t *testing.T) { + _, m, err := CreateTestTrace(`trace`) + if err != nil { + t.Errorf("Error parsing test input: %s", err) + return + } + if m.Name() != "trace" { + t.Errorf("Wrong name from GetName: %s", m.Name()) + } + err = m.OnStartup() + if err != nil { + t.Errorf("Error starting tracing middleware: %s", err) + return + } + if m.Tracer() == nil { + t.Errorf("Error, no tracer created") + } +}