diff --git a/plugin/grpc/README.md b/plugin/grpc/README.md index 5e6148da9..c2e8b34ea 100644 --- a/plugin/grpc/README.md +++ b/plugin/grpc/README.md @@ -129,7 +129,15 @@ Or with multiple upstreams from the same provider } ~~~ +Forward requests to a local upstream listening on a Unix domain socket. + +~~~ corefile +. { + grpc . unix:///path/to/grpc.sock +} +~~~ + ## Bugs The TLS config is global for the whole grpc proxy if you need a different `tls_servername` for -different upstreams you're out of luck. +different upstreams you're out of luck. \ No newline at end of file diff --git a/plugin/grpc/proxy_test.go b/plugin/grpc/proxy_test.go index cc4ebec82..534fde3d7 100644 --- a/plugin/grpc/proxy_test.go +++ b/plugin/grpc/proxy_test.go @@ -3,9 +3,15 @@ package grpc import ( "context" "errors" + "net" + "os" + "path" "testing" + "github.com/coredns/caddy" "github.com/coredns/coredns/pb" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/test" "github.com/miekg/dns" "google.golang.org/grpc" @@ -64,3 +70,56 @@ type testServiceClient struct { func (m testServiceClient) Query(ctx context.Context, in *pb.DnsPacket, opts ...grpc.CallOption) (*pb.DnsPacket, error) { return m.dnsPacket, m.err } + +func TestProxyUnix(t *testing.T) { + tdir, err := os.MkdirTemp("", "tmp*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tdir) + + fd := path.Join(tdir, "test.grpc") + listener, err := net.Listen("unix", fd) + if err != nil { + t.Fatal("Failed to listen: ", err) + } + defer listener.Close() + + server := grpc.NewServer() + pb.RegisterDnsServiceServer(server, &grpcDnsServiceServer{}) + + go server.Serve(listener) + defer server.Stop() + + c := caddy.NewTestController("dns", "grpc . unix://"+fd) + g, err := parseGRPC(c) + + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := g.ServeDNS(context.TODO(), rec, m); err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + if x := rec.Msg.Answer[0].Header().Name; x != "example.org." { + t.Errorf("Expected %s, got %s", "example.org.", x) + } +} + +type grpcDnsServiceServer struct { + pb.UnimplementedDnsServiceServer +} + +func (*grpcDnsServiceServer) Query(ctx context.Context, in *pb.DnsPacket) (*pb.DnsPacket, error) { + msg := &dns.Msg{} + msg.Unpack(in.GetMsg()) + answer := new(dns.Msg) + answer.Answer = append(answer.Answer, test.A("example.org. IN A 127.0.0.1")) + answer.SetRcode(msg, dns.RcodeSuccess) + buf, _ := answer.Pack() + return &pb.DnsPacket{Msg: buf}, nil +} diff --git a/plugin/grpc/setup_test.go b/plugin/grpc/setup_test.go index 1d9e93b7f..f142099c3 100644 --- a/plugin/grpc/setup_test.go +++ b/plugin/grpc/setup_test.go @@ -25,6 +25,7 @@ func TestSetup(t *testing.T) { {"grpc . 127.0.0.1:8080", false, ".", nil, ""}, {"grpc . [::1]:53", false, ".", nil, ""}, {"grpc . [2003::1]:53", false, ".", nil, ""}, + {"grpc . unix:///var/run/g.sock", false, ".", nil, ""}, // negative {"grpc . a27.0.0.1", true, "", nil, "not an IP"}, {"grpc . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, "unknown property"}, diff --git a/plugin/pkg/parse/host.go b/plugin/pkg/parse/host.go index 9206a033d..c396dc853 100644 --- a/plugin/pkg/parse/host.go +++ b/plugin/pkg/parse/host.go @@ -33,6 +33,14 @@ func HostPortOrFile(s ...string) ([]string, error) { var servers []string for _, h := range s { trans, host := Transport(h) + if len(host) == 0 { + return servers, fmt.Errorf("invalid address: %q", h) + } + + if trans == transport.UNIX { + servers = append(servers, trans+"://"+host) + continue + } addr, _, err := net.SplitHostPort(host) diff --git a/plugin/pkg/parse/host_test.go b/plugin/pkg/parse/host_test.go index 611f8284f..0b5f6f1ff 100644 --- a/plugin/pkg/parse/host_test.go +++ b/plugin/pkg/parse/host_test.go @@ -58,6 +58,16 @@ func TestHostPortOrFile(t *testing.T) { "", true, }, + { + "unix:///var/run/g.sock", + "unix:///var/run/g.sock", + false, + }, + { + "unix://", + "", + true, + }, } err := os.WriteFile("resolv.conf", []byte("nameserver 127.0.0.1\n"), 0600) diff --git a/plugin/pkg/parse/transport.go b/plugin/pkg/parse/transport.go index d632120d7..0da640856 100644 --- a/plugin/pkg/parse/transport.go +++ b/plugin/pkg/parse/transport.go @@ -27,6 +27,9 @@ func Transport(s string) (trans string, addr string) { s = s[len(transport.HTTPS+"://"):] return transport.HTTPS, s + case strings.HasPrefix(s, transport.UNIX+"://"): + s = s[len(transport.UNIX+"://"):] + return transport.UNIX, s } return transport.DNS, s diff --git a/plugin/pkg/transport/transport.go b/plugin/pkg/transport/transport.go index 85b3bee5f..e23b6d647 100644 --- a/plugin/pkg/transport/transport.go +++ b/plugin/pkg/transport/transport.go @@ -6,6 +6,7 @@ const ( TLS = "tls" GRPC = "grpc" HTTPS = "https" + UNIX = "unix" ) // Port numbers for the various transports.