Grpc tracing (#544)

* checkpoint

* Pass context through ServeDNS, enable gRPC tracing

* Fix types and make tracer available to proxy. go fmt

* Fix imports

* Use the DoNotStartTrace option

* Change to SpanFilter from DoNotStartTrace

* Use new name (IncludeSpan)

* Final names

* Add tests; fix possible client/conn leaks in grpc

* go fmt
This commit is contained in:
John Belamaric 2017-03-01 10:41:54 -05:00 committed by GitHub
parent 0a4903571e
commit 9ea8cde36e
8 changed files with 140 additions and 16 deletions

View file

@ -155,6 +155,10 @@ func (s *Server) Address() string { return s.Addr }
// defined in the request so that the correct zone // defined in the request so that the correct zone
// (configuration and middleware stack) will handle the request. // (configuration and middleware stack) will handle the request.
func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
s.ServeDNSWithContext(context.Background(), w, r)
}
func (s *Server) ServeDNSWithContext(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) {
defer func() { defer func() {
// In case the user doesn't enable error middleware, we still // In case the user doesn't enable error middleware, we still
// need to make sure that we stay alive up here // need to make sure that we stay alive up here
@ -171,7 +175,6 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
q := r.Question[0].Name q := r.Question[0].Name
b := make([]byte, len(q)) b := make([]byte, len(q))
off, end := 0, false off, end := 0, false
ctx := context.Background()
var dshandler *Config var dshandler *Config

View file

@ -5,16 +5,22 @@ import (
"crypto/tls" "crypto/tls"
"log" "log"
"github.com/grpc-ecosystem/grpc-opentracing/go/otgrpc"
"github.com/coredns/coredns/middleware/proxy/pb" "github.com/coredns/coredns/middleware/proxy/pb"
"github.com/coredns/coredns/middleware/trace"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
opentracing "github.com/opentracing/opentracing-go"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
type grpcClient struct { type grpcClient struct {
dialOpt grpc.DialOption dialOpts []grpc.DialOption
clients map[string]pb.DnsServiceClient clients map[string]pb.DnsServiceClient
conns []*grpc.ClientConn conns []*grpc.ClientConn
upstream *staticUpstream upstream *staticUpstream
@ -24,9 +30,9 @@ func newGrpcClient(tls *tls.Config, u *staticUpstream) *grpcClient {
g := &grpcClient{upstream: u} g := &grpcClient{upstream: u}
if tls == nil { if tls == nil {
g.dialOpt = grpc.WithInsecure() g.dialOpts = append(g.dialOpts, grpc.WithInsecure())
} else { } else {
g.dialOpt = grpc.WithTransportCredentials(credentials.NewTLS(tls)) g.dialOpts = append(g.dialOpts, grpc.WithTransportCredentials(credentials.NewTLS(tls)))
} }
g.clients = map[string]pb.DnsServiceClient{} g.clients = map[string]pb.DnsServiceClient{}
@ -54,18 +60,32 @@ func (g *grpcClient) Exchange(ctx context.Context, addr string, state request.Re
func (g *grpcClient) Protocol() string { return "grpc" } func (g *grpcClient) Protocol() string { return "grpc" }
func (g *grpcClient) OnShutdown(p *Proxy) error { func (g *grpcClient) OnShutdown(p *Proxy) error {
g.clients = map[string]pb.DnsServiceClient{}
for i, conn := range g.conns { for i, conn := range g.conns {
err := conn.Close() err := conn.Close()
if err != nil { if err != nil {
log.Printf("[WARNING] Error closing connection %d: %s\n", i, err) log.Printf("[WARNING] Error closing connection %d: %s\n", i, err)
} }
} }
g.conns = []*grpc.ClientConn{}
return nil return nil
} }
func (g *grpcClient) OnStartup(p *Proxy) error { func (g *grpcClient) OnStartup(p *Proxy) error {
dialOpts := g.dialOpts
if p.Trace != nil {
if t, ok := p.Trace.(trace.Trace); ok {
onlyIfParent := func(parentSpanCtx opentracing.SpanContext, method string, req, resp interface{}) bool {
return parentSpanCtx != nil
}
intercept := otgrpc.OpenTracingClientInterceptor(t.Tracer(), otgrpc.IncludingSpans(onlyIfParent))
dialOpts = append(dialOpts, grpc.WithUnaryInterceptor(intercept))
} else {
log.Printf("[WARNING] Wrong type for trace middleware reference: %s", p.Trace)
}
}
for _, host := range g.upstream.Hosts { for _, host := range g.upstream.Hosts {
conn, err := grpc.Dial(host.Name, g.dialOpt) conn, err := grpc.Dial(host.Name, dialOpts...)
if err != nil { if err != nil {
log.Printf("[WARNING] Skipping gRPC host '%s' due to Dial error: %s\n", host.Name, err) log.Printf("[WARNING] Skipping gRPC host '%s' due to Dial error: %s\n", host.Name, err)
} else { } else {

View file

@ -0,0 +1,54 @@
package proxy
import (
"testing"
"time"
)
// pool builds the upstream host list used by the tests in this file: two
// hosts named after local addresses on ports 10053 and 10054.
func pool() []*UpstreamHost {
	names := []string{"localhost:10053", "localhost:10054"}
	hosts := make([]*UpstreamHost, 0, len(names))
	for _, name := range names {
		hosts = append(hosts, &UpstreamHost{Name: name})
	}
	return hosts
}
func TestStartupShutdown(t *testing.T) {
upstream := &staticUpstream{
from: ".",
Hosts: pool(),
Policy: &Random{},
Spray: nil,
FailTimeout: 10 * time.Second,
MaxFails: 1,
}
g := newGrpcClient(nil, upstream)
upstream.ex = g
p := &Proxy{Trace: nil}
p.Upstreams = &[]Upstream{upstream}
err := g.OnStartup(p)
if err != nil {
t.Errorf("Error starting grpc client exchanger: %s", err)
return
}
if len(g.clients) != len(pool()) {
t.Errorf("Expected %d grpc clients but found %d", len(pool()), len(g.clients))
}
err = g.OnShutdown(p)
if err != nil {
t.Errorf("Error stopping grpc client exchanger: %s", err)
return
}
if len(g.clients) != 0 {
t.Errorf("Shutdown didn't remove clients, found %d", len(g.clients))
}
if len(g.conns) != 0 {
t.Errorf("Shutdown didn't remove conns, found %d", len(g.conns))
}
}

View file

@ -28,6 +28,10 @@ type Proxy struct {
// midway. // midway.
Upstreams *[]Upstream Upstreams *[]Upstream
// Trace is the Trace middleware, if it is installed.
// It is used by the grpc exchanger to trace through the grpc calls.
Trace middleware.Handler
} }
// Upstream manages a pool of proxy upstream hosts. Select should return a // Upstream manages a pool of proxy upstream hosts. Select should return a

View file

@ -20,7 +20,8 @@ func setup(c *caddy.Controller) error {
return middleware.Error("proxy", err) return middleware.Error("proxy", err)
} }
P := &Proxy{} t := dnsserver.GetMiddleware(c, "trace")
P := &Proxy{Trace: t}
dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler { dnsserver.GetConfig(c).AddMiddleware(func(next middleware.Handler) middleware.Handler {
P.Next = next P.Next = next
P.Upstreams = &upstreams P.Upstreams = &upstreams

View file

@ -34,9 +34,9 @@ func setup(c *caddy.Controller) error {
return nil return nil
} }
func traceParse(c *caddy.Controller) (*Trace, error) { func traceParse(c *caddy.Controller) (*trace, error) {
var ( var (
tr = &Trace{Endpoint: defEP, EndpointType: defEpType, every: 1, serviceName: defServiceName} tr = &trace{Endpoint: defEP, EndpointType: defEpType, every: 1, serviceName: defServiceName}
err error err error
) )

View file

@ -15,12 +15,17 @@ import (
) )
// Trace holds the tracer and endpoint info // Trace holds the tracer and endpoint info
type Trace struct { type Trace interface {
middleware.Handler
Tracer() ot.Tracer
}
type trace struct {
Next middleware.Handler Next middleware.Handler
ServiceEndpoint string ServiceEndpoint string
Endpoint string Endpoint string
EndpointType string EndpointType string
Tracer ot.Tracer tracer ot.Tracer
serviceName string serviceName string
clientServer bool clientServer bool
every uint64 every uint64
@ -28,8 +33,12 @@ type Trace struct {
Once sync.Once Once sync.Once
} }
// Tracer returns the ot.Tracer stored on this middleware; it is the method
// the Trace interface requires in addition to middleware.Handler.
// NOTE(review): the underlying field is assigned in setupZipkin, so this
// presumably returns nil before startup has run — confirm against callers.
func (t *trace) Tracer() ot.Tracer {
return t.tracer
}
// OnStartup sets up the tracer // OnStartup sets up the tracer
func (t *Trace) OnStartup() error { func (t *trace) OnStartup() error {
var err error var err error
t.Once.Do(func() { t.Once.Do(func() {
switch t.EndpointType { switch t.EndpointType {
@ -42,7 +51,7 @@ func (t *Trace) OnStartup() error {
return err return err
} }
func (t *Trace) setupZipkin() error { func (t *trace) setupZipkin() error {
collector, err := zipkin.NewHTTPCollector(t.Endpoint) collector, err := zipkin.NewHTTPCollector(t.Endpoint)
if err != nil { if err != nil {
@ -50,7 +59,7 @@ func (t *Trace) setupZipkin() error {
} }
recorder := zipkin.NewRecorder(collector, false, t.ServiceEndpoint, t.serviceName) recorder := zipkin.NewRecorder(collector, false, t.ServiceEndpoint, t.serviceName)
t.Tracer, err = zipkin.NewTracer(recorder, zipkin.ClientServerSameSpan(t.clientServer)) t.tracer, err = zipkin.NewTracer(recorder, zipkin.ClientServerSameSpan(t.clientServer))
if err != nil { if err != nil {
return err return err
} }
@ -58,12 +67,12 @@ func (t *Trace) setupZipkin() error {
} }
// Name implements the Handler interface. // Name implements the Handler interface.
func (t *Trace) Name() string { func (t *trace) Name() string {
return "trace" return "trace"
} }
// ServeDNS implements the middleware.Handle interface. // ServeDNS implements the middleware.Handle interface.
func (t *Trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (t *trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
trace := false trace := false
if t.every > 0 { if t.every > 0 {
queryNr := atomic.AddUint64(&t.count, 1) queryNr := atomic.AddUint64(&t.count, 1)
@ -73,7 +82,7 @@ func (t *Trace) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg)
} }
} }
if span := ot.SpanFromContext(ctx); span == nil && trace { if span := ot.SpanFromContext(ctx); span == nil && trace {
span := t.Tracer.StartSpan("servedns") span := t.Tracer().StartSpan("servedns")
defer span.Finish() defer span.Finish()
ctx = ot.ContextWithSpan(ctx, span) ctx = ot.ContextWithSpan(ctx, span)
} }

View file

@ -0,0 +1,33 @@
package trace
import (
"testing"
"github.com/mholt/caddy"
)
// CreateTestTrace builds a test controller for the "dns" server type from the
// given config text and parses it into a trace middleware. It returns the
// controller, the parsed middleware, and any error reported by traceParse.
func CreateTestTrace(config string) (*caddy.Controller, *trace, error) {
	controller := caddy.NewTestController("dns", config)
	parsed, parseErr := traceParse(controller)
	return controller, parsed, parseErr
}
// TestTrace parses a bare `trace` directive, then checks the middleware's
// reported name, that OnStartup returns no error, and that a tracer is
// available afterwards.
func TestTrace(t *testing.T) {
	_, tr, err := CreateTestTrace(`trace`)
	if err != nil {
		t.Errorf("Error parsing test input: %s", err)
		return
	}

	if name := tr.Name(); name != "trace" {
		t.Errorf("Wrong name from GetName: %s", name)
	}

	if err := tr.OnStartup(); err != nil {
		t.Errorf("Error starting tracing middleware: %s", err)
		return
	}

	if tr.Tracer() == nil {
		t.Errorf("Error, no tracer created")
	}
}