coredns/plugin/tsig/tsig_test.go
Chris O'Haver 68e141eff2
plugin/tsig: new plugin TSIG (#4957)
* expose tsig secrets via dnsserver.Config
* add tsig plugin

Signed-off-by: Chris O'Haver <cohaver@infoblox.com>
2022-06-27 15:48:34 -04:00

255 lines
5.7 KiB
Go

package tsig
import (
"context"
"fmt"
"testing"
"time"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
"github.com/miekg/dns"
)
// TestServeDNS checks TSIGServer.ServeDNS across a set of zone and
// required-type configurations. For each case it verifies the response rcode
// and whether the response carries a TSIG record.
func TestServeDNS(t *testing.T) {
	cases := []struct {
		desc        string
		zones       []string
		reqTypes    qTypes
		qType       uint16
		qTsig, all  bool
		expectRcode int
		expectTsig  bool
		statusError bool
	}{
		{
			desc:        "all types, signed request",
			zones:       []string{"."},
			all:         true,
			qType:       dns.TypeA,
			qTsig:       true,
			expectRcode: dns.RcodeSuccess,
			expectTsig:  true,
		},
		{
			desc:        "all types, unsigned request refused",
			zones:       []string{"."},
			all:         true,
			qType:       dns.TypeA,
			qTsig:       false,
			expectRcode: dns.RcodeRefused,
			expectTsig:  false,
		},
		{
			desc:        "out of zone, unsigned request",
			zones:       []string{"another.domain."},
			all:         true,
			qType:       dns.TypeA,
			qTsig:       false,
			expectRcode: dns.RcodeSuccess,
			expectTsig:  false,
		},
		{
			desc:        "out of zone, signed request gets no TSIG",
			zones:       []string{"another.domain."},
			all:         true,
			qType:       dns.TypeA,
			qTsig:       true,
			expectRcode: dns.RcodeSuccess,
			expectTsig:  false,
		},
		{
			desc:        "required type AXFR, signed request",
			zones:       []string{"."},
			reqTypes:    qTypes{dns.TypeAXFR: {}},
			qType:       dns.TypeAXFR,
			qTsig:       true,
			expectRcode: dns.RcodeSuccess,
			expectTsig:  true,
		},
		{
			desc:        "type not required, unsigned request",
			zones:       []string{"."},
			reqTypes:    qTypes{},
			qType:       dns.TypeA,
			qTsig:       false,
			expectRcode: dns.RcodeSuccess,
			expectTsig:  false,
		},
		{
			desc:        "type not required, signed request",
			zones:       []string{"."},
			reqTypes:    qTypes{},
			qType:       dns.TypeA,
			qTsig:       true,
			expectRcode: dns.RcodeSuccess,
			expectTsig:  true,
		},
		{
			desc:        "signature status error",
			zones:       []string{"."},
			all:         true,
			qType:       dns.TypeA,
			qTsig:       true,
			expectRcode: dns.RcodeNotAuth,
			expectTsig:  true,
			statusError: true,
		},
	}
	for i, tc := range cases {
		t.Run(fmt.Sprintf("%d_%s", i, tc.desc), func(t *testing.T) {
			tsig := TSIGServer{
				Zones: tc.zones,
				all:   tc.all,
				types: tc.reqTypes,
				Next:  testHandler(),
			}
			ctx := context.TODO()

			// ErrWriter reports dns.ErrSig as the TSIG status, simulating a
			// request whose signature failed verification.
			var w *dnstest.Recorder
			if tc.statusError {
				w = dnstest.NewRecorder(&ErrWriter{err: dns.ErrSig})
			} else {
				w = dnstest.NewRecorder(&test.ResponseWriter{})
			}

			r := new(dns.Msg)
			r.SetQuestion("test.example.", tc.qType)
			if tc.qTsig {
				r.SetTsig("test.key.", dns.HmacSHA256, 300, time.Now().Unix())
			}

			if _, err := tsig.ServeDNS(ctx, w, r); err != nil {
				t.Fatal(err)
			}
			// Fail the subtest cleanly rather than panicking on a nil Msg below.
			if w.Msg == nil {
				t.Fatal("expected a response to be written")
			}
			if w.Msg.Rcode != tc.expectRcode {
				t.Fatalf("expected rcode %v, got %v", tc.expectRcode, w.Msg.Rcode)
			}

			hasTsig := w.Msg.IsTsig() != nil
			if tc.expectTsig && !hasTsig {
				t.Fatal("expected TSIG in response")
			}
			if !tc.expectTsig && hasTsig {
				t.Fatal("expected no TSIG in response")
			}
		})
	}
}
// TestServeDNSTsigErrors checks the response TSIGServer produces for each TSIG
// status error reported by the ResponseWriter. Every case must produce:
//   - a NOTAUTH rcode;
//   - a response TSIG record whose Error field holds the matching TSIG error
//     code;
//   - the expected OtherData/OtherLen and TimeSigned values (only BADTIME
//     carries 6 bytes of other data and echoes the client's signing time).
func TestServeDNSTsigErrors(t *testing.T) {
	clientNow := time.Now().Unix()
	cases := []struct {
		desc              string
		tsigErr           error
		expectRcode       int
		expectError       int
		expectOtherLength int
		expectTimeSigned  int64
	}{
		{
			desc:              "Unknown Key",
			tsigErr:           dns.ErrSecret,
			expectRcode:       dns.RcodeNotAuth,
			expectError:       dns.RcodeBadKey,
			expectOtherLength: 0,
			expectTimeSigned:  0,
		},
		{
			desc:              "Bad Signature",
			tsigErr:           dns.ErrSig,
			expectRcode:       dns.RcodeNotAuth,
			expectError:       dns.RcodeBadSig,
			expectOtherLength: 0,
			expectTimeSigned:  0,
		},
		{
			desc:              "Bad Time",
			tsigErr:           dns.ErrTime,
			expectRcode:       dns.RcodeNotAuth,
			expectError:       dns.RcodeBadTime,
			expectOtherLength: 6,
			expectTimeSigned:  clientNow,
		},
	}

	tsig := TSIGServer{
		Zones: []string{"."},
		all:   true,
		Next:  testHandler(),
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			ctx := context.TODO()
			w := dnstest.NewRecorder(&ErrWriter{err: tc.tsigErr})

			r := new(dns.Msg)
			r.SetQuestion("test.example.", dns.TypeA)
			r.SetTsig("test.key.", dns.HmacSHA256, 300, clientNow)
			// Give the request a fake MAC and MACSize; the outcome is dictated
			// by the status ErrWriter reports, not by the MAC contents.
			rtsig := r.IsTsig()
			rtsig.MAC = "0123456789012345678901234567890101234567890123456789012345678901"
			rtsig.MACSize = 32

			if _, err := tsig.ServeDNS(ctx, w, r); err != nil {
				t.Fatal(err)
			}
			// Fail the subtest cleanly rather than panicking on a nil Msg below.
			if w.Msg == nil {
				t.Fatal("expected a response to be written")
			}
			if w.Msg.Rcode != tc.expectRcode {
				t.Fatalf("expected rcode %v, got %v", tc.expectRcode, w.Msg.Rcode)
			}

			ts := w.Msg.IsTsig()
			if ts == nil {
				t.Fatal("expected TSIG in response")
			}
			if int(ts.Error) != tc.expectError {
				t.Errorf("expected TSIG error code %v, got %v", tc.expectError, ts.Error)
			}
			// OtherData is a hex string (two characters per byte), while
			// expectOtherLength and OtherLen count bytes; report the same
			// byte count that is compared.
			if otherBytes := len(ts.OtherData) / 2; otherBytes != tc.expectOtherLength {
				t.Errorf("expected Other of length %v, got %v", tc.expectOtherLength, otherBytes)
			}
			if int(ts.OtherLen) != tc.expectOtherLength {
				t.Errorf("expected OtherLen %v, got %v", tc.expectOtherLength, ts.OtherLen)
			}
			if ts.TimeSigned != uint64(tc.expectTimeSigned) {
				t.Errorf("expected TimeSigned to be %v, got %v", tc.expectTimeSigned, ts.TimeSigned)
			}
		})
	}
}
// testHandler builds the downstream handler used as TSIGServer.Next in these
// tests.
//
// Queries for "test.example." are answered authoritatively with a single
// A record (1.2.3.48) and NOERROR. Any other name gets an empty SERVFAIL
// reply. A response is always written and the returned error is always nil.
func testHandler() test.HandlerFunc {
	return func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
		req := request.Request{W: w, Req: r}
		reply := new(dns.Msg)

		if req.Name() != "test.example." {
			reply.SetRcode(r, dns.RcodeServerFailure)
			w.WriteMsg(reply)
			return dns.RcodeServerFailure, nil
		}

		reply.SetReply(r)
		reply.Answer = []dns.RR{test.A("test.example. 300 IN A 1.2.3.48")}
		reply.Authoritative = true
		reply.SetRcode(r, dns.RcodeSuccess)
		w.WriteMsg(reply)
		return dns.RcodeSuccess, nil
	}
}
// ErrWriter wraps a test.ResponseWriter but reports a fixed error from
// TsigStatus. Tests use it to make TSIGServer see a chosen TSIG verification
// failure (e.g. dns.ErrSig, dns.ErrTime).
type ErrWriter struct {
	err error
	test.ResponseWriter
}

// TsigStatus ignores the request entirely and returns the err the writer was
// constructed with.
func (e *ErrWriter) TsigStatus() error {
	return e.err
}