diff --git a/acme/dns_challenge.go b/acme/dns_challenge.go index 8d2a213b..198a1bcb 100644 --- a/acme/dns_challenge.go +++ b/acme/dns_challenge.go @@ -1 +1,140 @@ package acme + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/miekg/dns" +) + +type preCheckDNSFunc func() bool + +var preCheckDNS = func() bool { + return true +} + +var preCheckDNSFallbackCount = 5 + +// DNSProvider represents a service for creating dns records. +type DNSProvider interface { + // CreateTXT creates a TXT record + CreateTXTRecord(fqdn, value string, ttl int) error + RemoveTXTRecord(fqdn, value string, ttl int) error +} + +// dnsChallenge implements the dns-01 challenge according to ACME 7.5 +type dnsChallenge struct { + jws *jws + provider DNSProvider +} + +func (s *dnsChallenge) Solve(chlng challenge, domain string) error { + + logf("[INFO] acme: Trying to solve DNS-01") + + // Generate the Key Authorization for the challenge + keyAuth, err := getKeyAuthorization(chlng.Token, &s.jws.privKey.PublicKey) + if err != nil { + return err + } + + keyAuthShaBytes := sha256.Sum256([]byte(keyAuth)) + // base64URL encoding without padding + keyAuthSha := base64.URLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size]) + keyAuthSha = strings.TrimRight(keyAuthSha, "=") + + fqdn := fmt.Sprintf("_acme-challenge.%s.", domain) + if err = s.provider.CreateTXTRecord(fqdn, keyAuthSha, 120); err != nil { + return err + } + + if preCheckDNS() { + // check if the expected DNS entry was created. If not wait for some time and try again. + m := new(dns.Msg) + m.SetQuestion(domain+".", dns.TypeSOA) + c := new(dns.Client) + in, _, err := c.Exchange(m, "8.8.8.8:53") + if err != nil { + return err + } + + var authorativeNS string + for _, answ := range in.Answer { + soa := answ.(*dns.SOA) + authorativeNS = soa.Ns + } + + fallbackCnt := 0 + for fallbackCnt < preCheckDNSFallbackCount { + m.SetQuestion(fqdn, dns.TypeTXT) + in, _, err = c.Exchange(m, authorativeNS+":53") + if err != nil { + return err + } + + if len(in.Answer) > 0 { + break + } + + fallbackCnt++ + if fallbackCnt >= preCheckDNSFallbackCount { + return errors.New("Could not retrieve the value from DNS in a timely manner. Aborting.") + } + + time.Sleep(time.Second * time.Duration(fallbackCnt)) + } + } + + jsonBytes, err := json.Marshal(challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) + if err != nil { + return errors.New("Failed to marshal network message...") + } + + // Tell the server we handle DNS-01 + resp, err := s.jws.post(chlng.URI, jsonBytes) + if err != nil { + return fmt.Errorf("Failed to post JWS message. -> %v", err) + } + + // Repeatedly check the server for an updated status on our request. + var challengeResponse challenge +Loop: + for { + if resp.StatusCode >= http.StatusBadRequest { + return handleHTTPError(resp) + } + + err = json.NewDecoder(resp.Body).Decode(&challengeResponse) + resp.Body.Close() + if err != nil { + return err + } + + switch challengeResponse.Status { + case "valid": + logf("The server validated our request") + break Loop + case "pending": + break + case "invalid": + return errors.New("The server could not validate our request.") + default: + return errors.New("The server returned an unexpected state.") + } + + time.Sleep(1 * time.Second) + resp, err = http.Get(chlng.URI) + } + + if err = s.provider.RemoveTXTRecord(fqdn, keyAuthSha, 120); err != nil { + logf("[WARN] acme: Failed to cleanup DNS record. -> %v ", err) + } + + return nil +} diff --git a/acme/dns_challenge_cloudflare.go b/acme/dns_challenge_cloudflare.go new file mode 100644 index 00000000..da7c8bed --- /dev/null +++ b/acme/dns_challenge_cloudflare.go @@ -0,0 +1,158 @@ +package acme + +import ( + "fmt" + "os" + "strings" + + "github.com/crackcomm/cloudflare" + "golang.org/x/net/context" +) + +// DNSProviderCloudFlare is an implementation of the DNSProvider interface +type DNSProviderCloudFlare struct { + client *cloudflare.Client + ctx context.Context +} + +// NewDNSProviderCloudFlare returns a DNSProviderCloudFlare instance with a configured cloudflare client. +// Authentication is either done using the passed credentials or - when empty - using the environment +// variables CLOUDFLARE_EMAIL and CLOUDFLARE_API_KEY. +func NewDNSProviderCloudFlare(cloudflareEmail, cloudflareKey string) (*DNSProviderCloudFlare, error) { + if cloudflareEmail == "" || cloudflareKey == "" { + cloudflareEmail, cloudflareKey = envAuth() + if cloudflareEmail == "" || cloudflareKey == "" { + return nil, fmt.Errorf("CloudFlare credentials missing") + } + } + + c := &DNSProviderCloudFlare{ + client: cloudflare.New(&cloudflare.Options{cloudflareEmail, cloudflareKey}), + ctx: context.Background(), + } + + return c, nil +} + +// CreateTXTRecord creates a TXT record using the specified parameters +func (c *DNSProviderCloudFlare) CreateTXTRecord(fqdn, value string, ttl int) error { + zoneID, err := c.getHostedZoneID(fqdn) + if err != nil { + return err + } + + record := newTxtRecord(zoneID, fqdn, value, ttl) + err = c.client.Records.Create(c.ctx, record) + if err != nil { + return fmt.Errorf("CloudFlare API call failed: %v", err) + } + + return nil +} + +// RemoveTXTRecord removes the TXT record matching the specified parameters +func (c *DNSProviderCloudFlare) RemoveTXTRecord(fqdn, value string, ttl int) error { + records, err := c.findTxtRecords(fqdn) + if err != nil { + return err + } + + for _, rec := range records { + err := c.client.Records.Delete(c.ctx, rec.ZoneID, rec.ID) + if err != nil { + return fmt.Errorf("CloudFlare API call has failed: %v", err) + } + } + + return nil +} + +func (c *DNSProviderCloudFlare) findTxtRecords(fqdn string) ([]*cloudflare.Record, error) { + zoneID, err := c.getHostedZoneID(fqdn) + if err != nil { + return nil, err + } + + var records []*cloudflare.Record + result, err := c.client.Records.List(c.ctx, zoneID) + if err != nil { + return records, fmt.Errorf("CloudFlare API call has failed: %v", err) + } + + name := unFqdn(fqdn) + for _, rec := range result { + if rec.Name == name && rec.Type == "TXT" { + records = append(records, rec) + } + } + + return records, nil +} + +func (c *DNSProviderCloudFlare) getHostedZoneID(fqdn string) (string, error) { + zones, err := c.client.Zones.List(c.ctx) + if err != nil { + return "", fmt.Errorf("CloudFlare API call failed: %v", err) + } + + var hostedZone cloudflare.Zone + for _, zone := range zones { + name := toFqdn(zone.Name) + if strings.HasSuffix(fqdn, name) { + if len(zone.Name) > len(hostedZone.Name) { + hostedZone = *zone + } + } + } + if hostedZone.ID == "" { + return "", fmt.Errorf("No matching CloudFlare zone found for domain %s", fqdn) + } + + return hostedZone.ID, nil +} + +func newTxtRecord(zoneID, fqdn, value string, ttl int) *cloudflare.Record { + name := unFqdn(fqdn) + return &cloudflare.Record{ + Type: "TXT", + Name: name, + Content: value, + TTL: sanitizeTTL(ttl), + ZoneID: zoneID, + } +} + +func toFqdn(name string) string { + n := len(name) + if n == 0 || name[n-1] == '.' { + return name + } + return name + "." +} + +func unFqdn(name string) string { + n := len(name) + if n != 0 && name[n-1] == '.' { + return name[:n-1] + } + return name +} + +// TTL must be between 120 and 86400 seconds +func sanitizeTTL(ttl int) int { + if ttl < 120 { + ttl = 120 + } else if ttl > 86400 { + ttl = 86400 + } + return ttl +} + +func envAuth() (email, apiKey string) { + email = os.Getenv("CLOUDFLARE_EMAIL") + apiKey = os.Getenv("CLOUDFLARE_API_KEY") + if len(email) == 0 || len(apiKey) == 0 { + return "", "" + } + return +} diff --git a/acme/dns_challenge_cloudflare_test.go b/acme/dns_challenge_cloudflare_test.go new file mode 100644 index 00000000..bb17f028 --- /dev/null +++ b/acme/dns_challenge_cloudflare_test.go @@ -0,0 +1,83 @@ +package acme + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + cflareLiveTest bool + cflareEmail string + cflareAPIKey string + cflareDomain string +) + +func init() { + cflareEmail = os.Getenv("CLOUDFLARE_EMAIL") + cflareAPIKey = os.Getenv("CLOUDFLARE_API_KEY") + cflareDomain = os.Getenv("CLOUDFLARE_DOMAIN") + if len(cflareEmail) > 0 && len(cflareAPIKey) > 0 && len(cflareDomain) > 0 { + cflareLiveTest = true + } +} + +func restoreCloudFlareEnv() { + os.Setenv("CLOUDFLARE_EMAIL", cflareEmail) + os.Setenv("CLOUDFLARE_API_KEY", cflareAPIKey) +} + +func TestNewDNSProviderCloudFlareValid(t *testing.T) { + os.Setenv("CLOUDFLARE_EMAIL", "") + os.Setenv("CLOUDFLARE_API_KEY", "") + _, err := NewDNSProviderCloudFlare("123", "123") + assert.NoError(t, err) + restoreCloudFlareEnv() +} + +func TestNewDNSProviderCloudFlareValidEnv(t *testing.T) { + os.Setenv("CLOUDFLARE_EMAIL", "test@example.com") + os.Setenv("CLOUDFLARE_API_KEY", "123") + _, err := NewDNSProviderCloudFlare("", "") + assert.NoError(t, err) + restoreCloudFlareEnv() +} + +func TestNewDNSProviderCloudFlareMissingCredErr(t *testing.T) { + os.Setenv("CLOUDFLARE_EMAIL", "") + os.Setenv("CLOUDFLARE_API_KEY", "") + _, err := NewDNSProviderCloudFlare("", "") + assert.EqualError(t, err, "CloudFlare credentials missing") + restoreCloudFlareEnv() +} + +func TestCloudFlareCreateTXTRecord(t *testing.T) { + if !cflareLiveTest { + t.Skip("skipping live test") + } + + provider, err := NewDNSProviderCloudFlare(cflareEmail, cflareAPIKey) + assert.NoError(t, err) + + fqdn := fmt.Sprintf("_acme-challenge.123.%s.", cflareDomain) + err = provider.CreateTXTRecord(fqdn, "123d==", 120) + assert.NoError(t, err) +} + +func TestCloudFlareRemoveTXTRecord(t *testing.T) { + if !cflareLiveTest { + t.Skip("skipping live test") + } + + time.Sleep(time.Second * 1) + + provider, err := NewDNSProviderCloudFlare(cflareEmail, cflareAPIKey) + assert.NoError(t, err) + + fqdn := fmt.Sprintf("_acme-challenge.123.%s.", cflareDomain) + err = provider.RemoveTXTRecord(fqdn, "123d==", 120) + assert.NoError(t, err) +} diff --git a/acme/dns_challenge_manual.go b/acme/dns_challenge_manual.go new file mode 100644 index 00000000..3f3805de --- /dev/null +++ b/acme/dns_challenge_manual.go @@ -0,0 +1,38 @@ +package acme + +import ( + "bufio" + "fmt" + "os" +) + +const ( + dnsTemplate = "%s %d IN TXT \"%s\"" +) + +// DNSProviderManual is an implementation of the DNSProvider interface +type DNSProviderManual struct{} + +// NewDNSProviderManual returns a DNSProviderManual instance. +func NewDNSProviderManual() (*DNSProviderManual, error) { + return &DNSProviderManual{}, nil +} + +// CreateTXTRecord prints instructions for manually creating the TXT record +func (*DNSProviderManual) CreateTXTRecord(fqdn, value string, ttl int) error { + dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value) + logf("[INFO] acme: Please create the following TXT record in your DNS zone:") + logf("[INFO] acme: %s", dnsRecord) + logf("[INFO] acme: Press 'Enter' when you are done") + reader := bufio.NewReader(os.Stdin) + _, _ = reader.ReadString('\n') + return nil +} + +// RemoveTXTRecord prints instructions for manually removing the TXT record +func (*DNSProviderManual) RemoveTXTRecord(fqdn, value string, ttl int) error { + dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value) + logf("[INFO] acme: You can now remove this TXT record from your DNS zone:") + logf("[INFO] acme: %s", dnsRecord) + return nil +} diff --git a/acme/dns_challenge_rfc2136.go b/acme/dns_challenge_rfc2136.go new file mode 100644 index 00000000..cf8412f8 --- /dev/null +++ b/acme/dns_challenge_rfc2136.go @@ -0,0 +1,84 @@ +package acme + +import ( + "fmt" + "github.com/miekg/dns" + "time" +) + +// DNSProviderRFC2136 is an implementation of the DNSProvider interface that +// uses dynamic DNS updates (RFC 2136) to create TXT records on a nameserver. +type DNSProviderRFC2136 struct { + nameserver string + zone string + tsigKey string + tsigSecret string +} + +// NewDNSProviderRFC2136 returns a new DNSProviderRFC2136 instance. +// To disable TSIG authentication 'tsigKey' and 'tsigSecret' must be set to the empty string. +// 'nameserver' must be a network address in the the form "host:port". 'zone' must be the fully +// qualified name of the zone. +func NewDNSProviderRFC2136(nameserver, zone, tsigKey, tsigSecret string) (*DNSProviderRFC2136, error) { + d := &DNSProviderRFC2136{ + nameserver: nameserver, + zone: zone, + } + if len(tsigKey) > 0 && len(tsigSecret) > 0 { + d.tsigKey = tsigKey + d.tsigSecret = tsigSecret + } + + return d, nil +} + +// CreateTXTRecord creates a TXT record using the specified parameters +func (r *DNSProviderRFC2136) CreateTXTRecord(fqdn, value string, ttl int) error { + return r.changeRecord("INSERT", fqdn, value, ttl) +} + +// RemoveTXTRecord removes the TXT record matching the specified parameters +func (r *DNSProviderRFC2136) RemoveTXTRecord(fqdn, value string, ttl int) error { + return r.changeRecord("REMOVE", fqdn, value, ttl) +} + +func (r *DNSProviderRFC2136) changeRecord(action, fqdn, value string, ttl int) error { + // Create RR + rr := new(dns.TXT) + rr.Hdr = dns.RR_Header{Name: fqdn, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: uint32(ttl)} + rr.Txt = []string{value} + rrs := make([]dns.RR, 1) + rrs[0] = rr + + // Create dynamic update packet + m := new(dns.Msg) + m.SetUpdate(dns.Fqdn(r.zone)) + switch action { + case "INSERT": + m.Insert(rrs) + case "REMOVE": + m.Remove(rrs) + default: + return fmt.Errorf("Unexpected action: %s", action) + } + + // Setup client + c := new(dns.Client) + c.SingleInflight = true + // TSIG authentication / msg signing + if len(r.tsigKey) > 0 && len(r.tsigSecret) > 0 { + m.SetTsig(dns.Fqdn(r.tsigKey), dns.HmacMD5, 300, time.Now().Unix()) + c.TsigSecret = map[string]string{dns.Fqdn(r.tsigKey): r.tsigSecret} + } + + // Send the query + reply, _, err := c.Exchange(m, r.nameserver) + if err != nil { + return fmt.Errorf("DNS update failed: %v", err) + } + if reply != nil && reply.Rcode != dns.RcodeSuccess { + return fmt.Errorf("DNS update failed. Server replied: %s", dns.RcodeToString[reply.Rcode]) + } + + return nil +} diff --git a/acme/dns_challenge_rfc2136_test.go b/acme/dns_challenge_rfc2136_test.go new file mode 100644 index 00000000..0b832fd3 --- /dev/null +++ b/acme/dns_challenge_rfc2136_test.go @@ -0,0 +1,227 @@ +package acme + +import ( + "bytes" + "github.com/miekg/dns" + "net" + "strings" + "sync" + "testing" + "time" +) + +var ( + rfc2136TestValue = "so6ZGir4GaZqI11h9UccBB==" + rfc2136TestFqdn = "_acme-challenge.123456789.www.example.com." + rfc2136TestZone = "example.com." + rfc2136TestTTL = 120 + rfc2136TestTsigKey = "example.com." + rfc2136TestTsigSecret = "IwBTJx9wrDp4Y1RyC3H0gA==" +) + +var reqChan = make(chan *dns.Msg, 10) + +func TestRFC2136CanaryLocalTestServer(t *testing.T) { + dns.HandleFunc("example.com.", serverHandlerHello) + defer dns.HandleRemove("example.com.") + + server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", false) + if err != nil { + t.Fatalf("Failed to start test server: %v", err) + } + defer server.Shutdown() + + c := new(dns.Client) + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeTXT) + r, _, err := c.Exchange(m, addrstr) + if err != nil || len(r.Extra) == 0 { + t.Fatalf("Failed to communicate with test server:", err) + } + txt := r.Extra[0].(*dns.TXT).Txt[0] + if txt != "Hello world" { + t.Error("Expected test server to return 'Hello world' but got: ", txt) + } +} + +func TestRFC2136ServerSuccess(t *testing.T) { + dns.HandleFunc(rfc2136TestZone, serverHandlerReturnSuccess) + defer dns.HandleRemove(rfc2136TestZone) + + server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", false) + if err != nil { + t.Fatalf("Failed to start test server: %v", err) + } + defer server.Shutdown() + + provider, err := NewDNSProviderRFC2136(addrstr, rfc2136TestZone, "", "") + if err != nil { + t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) + } + if err := provider.CreateTXTRecord(rfc2136TestFqdn, rfc2136TestValue, rfc2136TestTTL); err != nil { + t.Errorf("Expected CreateTXTRecord() to return no error but the error was -> %v", err) + } +} + +func TestRFC2136ServerError(t *testing.T) { + dns.HandleFunc(rfc2136TestZone, serverHandlerReturnErr) + defer dns.HandleRemove(rfc2136TestZone) + + server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", false) + if err != nil { + t.Fatalf("Failed to start test server: %v", err) + } + defer server.Shutdown() + + provider, err := NewDNSProviderRFC2136(addrstr, rfc2136TestZone, "", "") + if err != nil { + t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) + } + if err := provider.CreateTXTRecord(rfc2136TestFqdn, rfc2136TestValue, rfc2136TestTTL); err == nil { + t.Errorf("Expected CreateTXTRecord() to return an error but it did not.") + } else if !strings.Contains(err.Error(), "NOTZONE") { + t.Errorf("Expected CreateTXTRecord() to return an error with the 'NOTZONE' rcode string but it did not.") + } +} + +func TestRFC2136TsigClient(t *testing.T) { + dns.HandleFunc(rfc2136TestZone, serverHandlerReturnSuccess) + defer dns.HandleRemove(rfc2136TestZone) + + server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", true) + if err != nil { + t.Fatalf("Failed to start test server: %v", err) + } + defer server.Shutdown() + + provider, err := NewDNSProviderRFC2136(addrstr, rfc2136TestZone, rfc2136TestTsigKey, rfc2136TestTsigSecret) + if err != nil { + t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) + } + if err := provider.CreateTXTRecord(rfc2136TestFqdn, rfc2136TestValue, rfc2136TestTTL); err != nil { + t.Errorf("Expected CreateTXTRecord() to return no error but the error was -> %v", err) + } +} + +func TestRFC2136ValidUpdatePacket(t *testing.T) { + dns.HandleFunc(rfc2136TestZone, serverHandlerPassBackRequest) + defer dns.HandleRemove(rfc2136TestZone) + + server, addrstr, err := runLocalDNSTestServer("127.0.0.1:0", false) + if err != nil { + t.Fatalf("Failed to start test server: %v", err) + } + defer server.Shutdown() + + rr := new(dns.TXT) + rr.Hdr = dns.RR_Header{ + Name: rfc2136TestFqdn, + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: uint32(rfc2136TestTTL), + } + rr.Txt = []string{rfc2136TestValue} + rrs := make([]dns.RR, 1) + rrs[0] = rr + m := new(dns.Msg) + m.SetUpdate(dns.Fqdn(rfc2136TestZone)) + m.Insert(rrs) + expectstr := m.String() + expect, err := m.Pack() + if err != nil { + t.Fatalf("Error packing expect msg: %v", err) + } + + provider, err := NewDNSProviderRFC2136(addrstr, rfc2136TestZone, "", "") + if err != nil { + t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) + } + if err := provider.CreateTXTRecord(rfc2136TestFqdn, rfc2136TestValue, rfc2136TestTTL); err != nil { + t.Errorf("Expected CreateTXTRecord() to return no error but the error was -> %v", err) + } + + rcvMsg := <-reqChan + rcvMsg.Id = m.Id + actual, err := rcvMsg.Pack() + if err != nil { + t.Fatalf("Error packing actual msg: %v", err) + } + + if !bytes.Equal(actual, expect) { + tmp := new(dns.Msg) + if err := tmp.Unpack(actual); err != nil { + t.Fatalf("Error unpacking actual msg: %v", err) + } + t.Errorf("Expected msg:\n%s", expectstr) + t.Errorf("Actual msg:\n%v", tmp) + } +} + +func runLocalDNSTestServer(listenAddr string, tsig bool) (*dns.Server, string, error) { + pc, err := net.ListenPacket("udp", listenAddr) + if err != nil { + return nil, "", err + } + server := &dns.Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour} + if tsig { + server.TsigSecret = map[string]string{rfc2136TestTsigKey: rfc2136TestTsigSecret} + } + + waitLock := sync.Mutex{} + waitLock.Lock() + server.NotifyStartedFunc = waitLock.Unlock + + go func() { + server.ActivateAndServe() + pc.Close() + }() + + waitLock.Lock() + return server, pc.LocalAddr().String(), nil +} + +func serverHandlerHello(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + m.Extra = make([]dns.RR, 1) + m.Extra[0] = &dns.TXT{ + Hdr: dns.RR_Header{Name: m.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0}, + Txt: []string{"Hello world"}, + } + w.WriteMsg(m) +} + +func serverHandlerReturnSuccess(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if t := req.IsTsig(); t != nil { + if w.TsigStatus() == nil { + // Validated + m.SetTsig(rfc2136TestZone, dns.HmacMD5, 300, time.Now().Unix()) + } + } + + w.WriteMsg(m) +} + +func serverHandlerReturnErr(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetRcode(req, dns.RcodeNotZone) + w.WriteMsg(m) +} + +func serverHandlerPassBackRequest(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + + if t := req.IsTsig(); t != nil { + if w.TsigStatus() == nil { + // Validated + m.SetTsig(rfc2136TestZone, dns.HmacMD5, 300, time.Now().Unix()) + } + } + + w.WriteMsg(m) + reqChan <- req +} diff --git a/acme/dns_challenge_route53.go b/acme/dns_challenge_route53.go new file mode 100644 index 00000000..764b93f6 --- /dev/null +++ b/acme/dns_challenge_route53.go @@ -0,0 +1,93 @@ +package acme + +import ( + "fmt" + "github.com/mitchellh/goamz/aws" + "github.com/mitchellh/goamz/route53" + "math" + "strings" +) + +// DNSProviderRoute53 is an implementation of the DNSProvider interface +type DNSProviderRoute53 struct { + client *route53.Route53 +} + +// NewDNSProviderRoute53 returns a DNSProviderRoute53 instance with a configured route53 client. +// Authentication is either done using the passed credentials or - when empty - +// using the environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY. +func NewDNSProviderRoute53(awsAccessKey, awsSecretKey, awsRegionName string) (*DNSProviderRoute53, error) { + region, ok := aws.Regions[awsRegionName] + if !ok { + return nil, fmt.Errorf("Invalid AWS region name %s", awsRegionName) + } + + var auth aws.Auth + // First try passed in credentials + if awsAccessKey != "" && awsSecretKey != "" { + auth = aws.Auth{awsAccessKey, awsSecretKey, ""} + } else { + // try getting credentials from environment + envAuth, err := aws.EnvAuth() + if err != nil { + return nil, fmt.Errorf("AWS credentials missing") + } + auth = envAuth + } + + client := route53.New(auth, region) + return &DNSProviderRoute53{client: client}, nil +} + +// CreateTXTRecord creates a TXT record using the specified parameters +func (r *DNSProviderRoute53) CreateTXTRecord(fqdn, value string, ttl int) error { + return r.changeRecord("UPSERT", fqdn, value, ttl) +} + +// RemoveTXTRecord removes the TXT record matching the specified parameters +func (r *DNSProviderRoute53) RemoveTXTRecord(fqdn, value string, ttl int) error { + return r.changeRecord("DELETE", fqdn, value, ttl) +} + +func (r *DNSProviderRoute53) changeRecord(action, fqdn, value string, ttl int) error { + hostedZoneID, err := r.getHostedZoneID(fqdn) + if err != nil { + return err + } + recordSet := newTXTRecordSet(fqdn, value, ttl) + update := route53.Change{action, recordSet} + changes := []route53.Change{update} + req := route53.ChangeResourceRecordSetsRequest{Comment: "Created by Lego", Changes: changes} + _, err = r.client.ChangeResourceRecordSets(hostedZoneID, &req) + return err +} + +func (r *DNSProviderRoute53) getHostedZoneID(fqdn string) (string, error) { + zoneResp, err := r.client.ListHostedZones("", math.MaxInt64) + if err != nil { + return "", err + } + var hostedZone route53.HostedZone + for _, zone := range zoneResp.HostedZones { + //if strings.HasSuffix(domain, strings.Trim(zone.Name, ".")) { + if strings.HasSuffix(fqdn, zone.Name) { + if len(zone.Name) > len(hostedZone.Name) { + hostedZone = zone + } + } + } + if hostedZone.ID == "" { + return "", fmt.Errorf("No Route53 zone found for domain %s", fqdn) + } + + return hostedZone.ID, nil +} + +func newTXTRecordSet(fqdn, value string, ttl int) route53.ResourceRecordSet { + return route53.ResourceRecordSet{ + Name: fqdn, + Type: "TXT", + Records: []string{value}, + TTL: ttl, + } +} diff --git a/acme/dns_challenge_route53_test.go b/acme/dns_challenge_route53_test.go new file mode 100644 index 00000000..c377b84e --- /dev/null +++ b/acme/dns_challenge_route53_test.go @@ -0,0 +1,126 @@ +package acme + +import ( + "os" + "testing" + + "github.com/mitchellh/goamz/aws" + "github.com/mitchellh/goamz/route53" + "github.com/mitchellh/goamz/testutil" + "github.com/stretchr/testify/assert" +) + +var ( + route53Secret string + route53Key string + testServer *testutil.HTTPServer +) + +var ChangeResourceRecordSetsAnswer = ` + + + /change/asdf + PENDING + 2014 + +` + +var ListHostedZonesAnswer = ` + + + + /hostedzone/Z2K123214213123 + example.com. + D2224C5B-684A-DB4A-BB9A-E09E3BAFEA7A + + Test comment + + 10 + + + /hostedzone/ZLT12321321124 + sub.example.com. + A970F076-FCB1-D959-B395-96474CC84EB8 + + Test comment for subdomain host + + 4 + + + false + 100 +` + +var serverResponseMap = testutil.ResponseMap{ + "/2013-04-01/hostedzone/": testutil.Response{200, nil, ListHostedZonesAnswer}, + "/2013-04-01/hostedzone/Z2K123214213123/rrset": testutil.Response{200, nil, ChangeResourceRecordSetsAnswer}, +} + +func init() { + route53Key = os.Getenv("AWS_ACCESS_KEY_ID") + route53Secret = os.Getenv("AWS_SECRET_ACCESS_KEY") + testServer = testutil.NewHTTPServer() + testServer.Start() +} + +func restoreRoute53Env() { + os.Setenv("AWS_ACCESS_KEY_ID", route53Key) + os.Setenv("AWS_SECRET_ACCESS_KEY", route53Secret) +} + +func makeRoute53TestServer() *testutil.HTTPServer { + testServer.Flush() + return testServer +} + +func makeRoute53Provider(server *testutil.HTTPServer) *DNSProviderRoute53 { + auth := aws.Auth{"abc", "123", ""} + client := route53.NewWithClient(auth, aws.Region{Route53Endpoint: server.URL}, testutil.DefaultClient) + return &DNSProviderRoute53{client: client} +} + +func TestNewDNSProviderRoute53Valid(t *testing.T) { + os.Setenv("AWS_ACCESS_KEY_ID", "") + os.Setenv("AWS_SECRET_ACCESS_KEY", "") + _, err := NewDNSProviderRoute53("123", "123", "us-east-1") + assert.NoError(t, err) + restoreRoute53Env() +} + +func TestNewDNSProviderRoute53ValidEnv(t *testing.T) { + os.Setenv("AWS_ACCESS_KEY_ID", "123") + os.Setenv("AWS_SECRET_ACCESS_KEY", "123") + _, err := NewDNSProviderRoute53("", "", "us-east-1") + assert.NoError(t, err) + restoreRoute53Env() +} + +func TestNewDNSProviderRoute53MissingAuthErr(t *testing.T) { + os.Setenv("AWS_ACCESS_KEY_ID", "") + os.Setenv("AWS_SECRET_ACCESS_KEY", "") + _, err := NewDNSProviderRoute53("", "", "us-east-1") + assert.EqualError(t, err, "AWS credentials missing") + restoreRoute53Env() +} + +func TestNewDNSProviderRoute53InvalidRegionErr(t *testing.T) { + _, err := NewDNSProviderRoute53("123", "123", "us-east-3") + assert.EqualError(t, err, "Invalid AWS region name us-east-3") +} + +func TestRoute53CreateTXTRecord(t *testing.T) { + assert := assert.New(t) + testServer := makeRoute53TestServer() + provider := makeRoute53Provider(testServer) + testServer.ResponseMap(2, serverResponseMap) + + err := provider.CreateTXTRecord("_acme-challenge.123.example.com.", "123456d==", 120) + assert.NoError(err, "Expected CreateTXTRecord to return no error") + + httpReqs := testServer.WaitRequests(2) + httpReq := httpReqs[1] + + assert.Equal("/2013-04-01/hostedzone/Z2K123214213123/rrset", httpReq.URL.Path, + "Expected CreateTXTRecord to select the correct hostedzone") + +} diff --git a/acme/dns_challenge_test.go b/acme/dns_challenge_test.go new file mode 100644 index 00000000..e69d7c4b --- /dev/null +++ b/acme/dns_challenge_test.go @@ -0,0 +1,39 @@ +package acme + +import ( + "bufio" + "crypto/rsa" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" +) + +func TestDNSValidServerResponse(t *testing.T) { + preCheckDNS = func() bool { + return false + } + privKey, _ := generatePrivateKey(rsakey, 512) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Replay-Nonce", "12345") + w.Write([]byte("{\"type\":\"dns01\",\"status\":\"valid\",\"uri\":\"http://some.url\",\"token\":\"http8\"}")) + })) + + manualProvider, _ := NewDNSProviderManual() + jws := &jws{privKey: privKey.(*rsa.PrivateKey), directoryURL: ts.URL} + solver := &dnsChallenge{jws: jws, provider: manualProvider} + clientChallenge := challenge{Type: "dns01", Status: "pending", URI: ts.URL, Token: "http8"} + + go func() { + time.Sleep(time.Second * 2) + f := bufio.NewWriter(os.Stdout) + defer f.Flush() + f.WriteString("\n") + }() + + if err := solver.Solve(clientChallenge, "example.com"); err != nil { + t.Errorf("VALID: Expected Solve to return no error but the error was -> %v", err) + } +}