252 lines
6.7 KiB
Go
252 lines
6.7 KiB
Go
package rfc2136
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/go-acme/lego/v3/challenge/dns01"
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// Fixtures shared by the tests and DNS handlers in this file.
const (
	// fakeDomain is the domain handed to Present in every provider test.
	fakeDomain = "123456789.www.example.com"

	// fakeKeyAuth is the key authorization handed to Present by the
	// success, error and TSIG tests.
	fakeKeyAuth = "123d=="

	// fakeValue is the TXT record content of the update message that
	// TestValidUpdatePacket expects the provider to send.
	fakeValue = "Now36o-3BmlB623-0c1qCIUmgWVVmDJb88KGl24pqpo"

	// fakeFqdn is the fully qualified challenge record name for fakeDomain.
	fakeFqdn = "_acme-challenge." + fakeDomain + "."

	// fakeZone is the zone the test handlers are registered for and the
	// owner name of the SOA record they return.
	fakeZone = "example.com."

	// fakeTTL is the TTL used for the expected TXT record and the fake SOA.
	fakeTTL = 120

	// fakeTsigKey and fakeTsigSecret are the TSIG key name and secret
	// configured on both the test server and the provider.
	fakeTsigKey    = "example.com."
	fakeTsigSecret = "IwBTJx9wrDp4Y1RyC3H0gA=="
)
|
|
|
|
func TestCanaryLocalTestServer(t *testing.T) {
|
|
dns01.ClearFqdnCache()
|
|
dns.HandleFunc("example.com.", serverHandlerHello)
|
|
defer dns.HandleRemove("example.com.")
|
|
|
|
server, addr, err := runLocalDNSTestServer(false)
|
|
require.NoError(t, err, "Failed to start test server")
|
|
defer func() { _ = server.Shutdown() }()
|
|
|
|
c := new(dns.Client)
|
|
m := new(dns.Msg)
|
|
|
|
m.SetQuestion("example.com.", dns.TypeTXT)
|
|
|
|
r, _, err := c.Exchange(m, addr)
|
|
require.NoError(t, err, "Failed to communicate with test server")
|
|
assert.Len(t, r.Extra, 1, "Failed to communicate with test server")
|
|
|
|
txt := r.Extra[0].(*dns.TXT).Txt[0]
|
|
assert.Equal(t, "Hello world", txt)
|
|
}
|
|
|
|
func TestServerSuccess(t *testing.T) {
|
|
dns01.ClearFqdnCache()
|
|
dns.HandleFunc(fakeZone, serverHandlerReturnSuccess)
|
|
defer dns.HandleRemove(fakeZone)
|
|
|
|
server, addr, err := runLocalDNSTestServer(false)
|
|
require.NoError(t, err, "Failed to start test server")
|
|
defer func() { _ = server.Shutdown() }()
|
|
|
|
config := NewDefaultConfig()
|
|
config.Nameserver = addr
|
|
|
|
provider, err := NewDNSProviderConfig(config)
|
|
require.NoError(t, err)
|
|
|
|
err = provider.Present(fakeDomain, "", fakeKeyAuth)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestServerError(t *testing.T) {
|
|
dns01.ClearFqdnCache()
|
|
dns.HandleFunc(fakeZone, serverHandlerReturnErr)
|
|
defer dns.HandleRemove(fakeZone)
|
|
|
|
server, addr, err := runLocalDNSTestServer(false)
|
|
require.NoError(t, err, "Failed to start test server")
|
|
defer func() { _ = server.Shutdown() }()
|
|
|
|
config := NewDefaultConfig()
|
|
config.Nameserver = addr
|
|
|
|
provider, err := NewDNSProviderConfig(config)
|
|
require.NoError(t, err)
|
|
|
|
err = provider.Present(fakeDomain, "", fakeKeyAuth)
|
|
require.Error(t, err)
|
|
if !strings.Contains(err.Error(), "NOTZONE") {
|
|
t.Errorf("Expected Present() to return an error with the 'NOTZONE' rcode string but it did not: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestTsigClient(t *testing.T) {
|
|
dns01.ClearFqdnCache()
|
|
dns.HandleFunc(fakeZone, serverHandlerReturnSuccess)
|
|
defer dns.HandleRemove(fakeZone)
|
|
|
|
server, addr, err := runLocalDNSTestServer(true)
|
|
require.NoError(t, err, "Failed to start test server")
|
|
defer func() { _ = server.Shutdown() }()
|
|
|
|
config := NewDefaultConfig()
|
|
config.Nameserver = addr
|
|
config.TSIGKey = fakeTsigKey
|
|
config.TSIGSecret = fakeTsigSecret
|
|
|
|
provider, err := NewDNSProviderConfig(config)
|
|
require.NoError(t, err)
|
|
|
|
err = provider.Present(fakeDomain, "", fakeKeyAuth)
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestValidUpdatePacket(t *testing.T) {
|
|
reqChan := make(chan *dns.Msg, 10)
|
|
|
|
dns01.ClearFqdnCache()
|
|
dns.HandleFunc(fakeZone, serverHandlerPassBackRequest(reqChan))
|
|
defer dns.HandleRemove(fakeZone)
|
|
|
|
server, addr, err := runLocalDNSTestServer(false)
|
|
require.NoError(t, err, "Failed to start test server")
|
|
defer func() { _ = server.Shutdown() }()
|
|
|
|
txtRR, _ := dns.NewRR(fmt.Sprintf("%s %d IN TXT %s", fakeFqdn, fakeTTL, fakeValue))
|
|
rrs := []dns.RR{txtRR}
|
|
m := new(dns.Msg)
|
|
m.SetUpdate(fakeZone)
|
|
m.RemoveRRset(rrs)
|
|
m.Insert(rrs)
|
|
expectStr := m.String()
|
|
|
|
expect, err := m.Pack()
|
|
require.NoError(t, err, "error packing")
|
|
|
|
config := NewDefaultConfig()
|
|
config.Nameserver = addr
|
|
|
|
provider, err := NewDNSProviderConfig(config)
|
|
require.NoError(t, err)
|
|
|
|
err = provider.Present(fakeDomain, "", "1234d==")
|
|
require.NoError(t, err)
|
|
|
|
rcvMsg := <-reqChan
|
|
rcvMsg.Id = m.Id
|
|
|
|
actual, err := rcvMsg.Pack()
|
|
require.NoError(t, err, "error packing")
|
|
|
|
if !bytes.Equal(actual, expect) {
|
|
tmp := new(dns.Msg)
|
|
if err := tmp.Unpack(actual); err != nil {
|
|
t.Fatalf("Error unpacking actual msg: %v", err)
|
|
}
|
|
t.Errorf("Expected msg:\n%s", expectStr)
|
|
t.Errorf("Actual msg:\n%v", tmp)
|
|
}
|
|
}
|
|
|
|
func runLocalDNSTestServer(tsig bool) (*dns.Server, string, error) {
|
|
pc, err := net.ListenPacket("udp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
server := &dns.Server{
|
|
PacketConn: pc,
|
|
ReadTimeout: time.Hour,
|
|
WriteTimeout: time.Hour,
|
|
MsgAcceptFunc: func(dh dns.Header) dns.MsgAcceptAction {
|
|
// bypass defaultMsgAcceptFunc to allow dynamic update (https://github.com/miekg/dns/pull/830)
|
|
return dns.MsgAccept
|
|
},
|
|
}
|
|
|
|
if tsig {
|
|
server.TsigSecret = map[string]string{fakeTsigKey: fakeTsigSecret}
|
|
}
|
|
|
|
waitLock := sync.Mutex{}
|
|
waitLock.Lock()
|
|
server.NotifyStartedFunc = waitLock.Unlock
|
|
|
|
go func() {
|
|
_ = server.ActivateAndServe()
|
|
pc.Close()
|
|
}()
|
|
|
|
waitLock.Lock()
|
|
return server, pc.LocalAddr().String(), nil
|
|
}
|
|
|
|
func serverHandlerHello(w dns.ResponseWriter, req *dns.Msg) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(req)
|
|
m.Extra = make([]dns.RR, 1)
|
|
m.Extra[0] = &dns.TXT{
|
|
Hdr: dns.RR_Header{Name: m.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
|
|
Txt: []string{"Hello world"},
|
|
}
|
|
_ = w.WriteMsg(m)
|
|
}
|
|
|
|
func serverHandlerReturnSuccess(w dns.ResponseWriter, req *dns.Msg) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(req)
|
|
if req.Opcode == dns.OpcodeQuery && req.Question[0].Qtype == dns.TypeSOA && req.Question[0].Qclass == dns.ClassINET {
|
|
// Return SOA to appease findZoneByFqdn()
|
|
soaRR, _ := dns.NewRR(fmt.Sprintf("%s %d IN SOA ns1.%s admin.%s 2016022801 28800 7200 2419200 1200", fakeZone, fakeTTL, fakeZone, fakeZone))
|
|
m.Answer = []dns.RR{soaRR}
|
|
}
|
|
|
|
if t := req.IsTsig(); t != nil {
|
|
if w.TsigStatus() == nil {
|
|
// Validated
|
|
m.SetTsig(fakeZone, dns.HmacMD5, 300, time.Now().Unix())
|
|
}
|
|
}
|
|
|
|
_ = w.WriteMsg(m)
|
|
}
|
|
|
|
func serverHandlerReturnErr(w dns.ResponseWriter, req *dns.Msg) {
|
|
m := new(dns.Msg)
|
|
m.SetRcode(req, dns.RcodeNotZone)
|
|
_ = w.WriteMsg(m)
|
|
}
|
|
|
|
func serverHandlerPassBackRequest(reqChan chan *dns.Msg) func(w dns.ResponseWriter, req *dns.Msg) {
|
|
return func(w dns.ResponseWriter, req *dns.Msg) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(req)
|
|
if req.Opcode == dns.OpcodeQuery && req.Question[0].Qtype == dns.TypeSOA && req.Question[0].Qclass == dns.ClassINET {
|
|
// Return SOA to appease findZoneByFqdn()
|
|
soaRR, _ := dns.NewRR(fmt.Sprintf("%s %d IN SOA ns1.%s admin.%s 2016022801 28800 7200 2419200 1200", fakeZone, fakeTTL, fakeZone, fakeZone))
|
|
m.Answer = []dns.RR{soaRR}
|
|
}
|
|
|
|
if t := req.IsTsig(); t != nil {
|
|
if w.TsigStatus() == nil {
|
|
// Validated
|
|
m.SetTsig(fakeZone, dns.HmacMD5, 300, time.Now().Unix())
|
|
}
|
|
}
|
|
|
|
_ = w.WriteMsg(m)
|
|
if req.Opcode != dns.OpcodeQuery || req.Question[0].Qtype != dns.TypeSOA || req.Question[0].Qclass != dns.ClassINET {
|
|
// Only talk back when it is not the SOA RR.
|
|
reqChan <- req
|
|
}
|
|
}
|
|
}
|