Refactor challenge providers to new ChallengeProvider interface

* new ChallengeProvider with Present and CleanUp methods
* new Challenge type describing `http-01`, `tls-sni-01`, `dns-01`
* new client.SetChallengeProvider to support custom implementations
This commit is contained in:
Jehiah Czebotar 2016-01-14 23:06:25 -05:00
parent 2e5ae296cc
commit 617dd4d37c
20 changed files with 317 additions and 185 deletions

15
acme/challenges.go Normal file
View file

@ -0,0 +1,15 @@
package acme
type Challenge string
const (
// HTTP01 is the "http-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#http
// Note: HTTP01ChallengePath returns the URL path to fulfill this challenge
HTTP01 = Challenge("http-01")
// TLSSNI01 is the "tls-sni-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#tls-with-server-name-indication-tls-sni
// Note: TLSSNI01ChallengeCert returns a certificate to fulfill this challenge
TLSSNI01 = Challenge("tls-sni-01")
// DNS01 is the "dns-01" ACME challenge https://github.com/ietf-wg-acme/acme/blob/master/draft-ietf-acme-acme.md#dns
// Note: DNS01Record returns a DNS record which will fulfill this challenge
DNS01 = Challenge("dns-01")
)

View file

@ -54,7 +54,7 @@ type Client struct {
jws *jws jws *jws
keyBits int keyBits int
issuerCert []byte issuerCert []byte
solvers map[string]solver solvers map[Challenge]solver
} }
// NewClient creates a new ACME client on behalf of the user. The client will depend on // NewClient creates a new ACME client on behalf of the user. The client will depend on
@ -93,13 +93,28 @@ func NewClient(caDirURL string, user User, keyBits int) (*Client, error) {
// REVIEW: best possibility? // REVIEW: best possibility?
// Add all available solvers with the right index as per ACME // Add all available solvers with the right index as per ACME
// spec to this map. Otherwise they won`t be found. // spec to this map. Otherwise they won`t be found.
solvers := make(map[string]solver) solvers := make(map[Challenge]solver)
solvers["http-01"] = &httpChallenge{jws: jws, validate: validate} solvers[HTTP01] = &httpChallenge{jws: jws, validate: validate}
solvers["tls-sni-01"] = &tlsSNIChallenge{jws: jws, validate: validate} solvers[TLSSNI01] = &tlsSNIChallenge{jws: jws, validate: validate}
return &Client{directory: dir, user: user, jws: jws, keyBits: keyBits, solvers: solvers}, nil return &Client{directory: dir, user: user, jws: jws, keyBits: keyBits, solvers: solvers}, nil
} }
// SetChallengeProvider specifies a custom provider that will make the solution available
func (c *Client) SetChallengeProvider(challenge Challenge, p ChallengeProvider) error {
switch challenge {
case HTTP01:
c.solvers[challenge] = &httpChallenge{jws: c.jws, validate: validate, provider: p}
case TLSSNI01:
c.solvers[challenge] = &tlsSNIChallenge{jws: c.jws, validate: validate, provider: p}
case DNS01:
c.solvers[challenge] = &dnsChallenge{jws: c.jws, provider: p}
default:
return fmt.Errorf("Unknown challenge %v", challenge)
}
return nil
}
// SetHTTPAddress specifies a custom interface:port to be used for HTTP based challenges. // SetHTTPAddress specifies a custom interface:port to be used for HTTP based challenges.
// If this option is not used, the default port 80 and all interfaces will be used. // If this option is not used, the default port 80 and all interfaces will be used.
// To only specify a port and no interface use the ":port" notation. // To only specify a port and no interface use the ":port" notation.
@ -109,9 +124,8 @@ func (c *Client) SetHTTPAddress(iface string) error {
return err return err
} }
if chlng, ok := c.solvers["http-01"]; ok { if chlng, ok := c.solvers[HTTP01]; ok {
chlng.(*httpChallenge).iface = host chlng.(*httpChallenge).provider = &httpChallengeServer{iface: host, port: port}
chlng.(*httpChallenge).port = port
} }
return nil return nil
@ -126,23 +140,19 @@ func (c *Client) SetTLSAddress(iface string) error {
return err return err
} }
if chlng, ok := c.solvers["tls-sni-01"]; ok { if chlng, ok := c.solvers[TLSSNI01]; ok {
chlng.(*tlsSNIChallenge).iface = host chlng.(*tlsSNIChallenge).provider = &tlsSNIChallengeServer{iface: host, port: port}
chlng.(*tlsSNIChallenge).port = port
} }
return nil return nil
} }
// ExcludeChallenges explicitly removes challenges from the pool for solving. // ExcludeChallenges explicitly removes challenges from the pool for solving.
func (c *Client) ExcludeChallenges(challenges []string) { func (c *Client) ExcludeChallenges(challenges []Challenge) {
// Loop through all challenges and delete the requested one if found. // Loop through all challenges and delete the requested one if found.
for _, challenge := range challenges { for _, challenge := range challenges {
if _, ok := c.solvers[challenge]; ok {
delete(c.solvers, challenge) delete(c.solvers, challenge)
} }
} }
}
// Register the current account to the ACME server. // Register the current account to the ACME server.
func (c *Client) Register() (*RegistrationResource, error) { func (c *Client) Register() (*RegistrationResource, error) {

View file

@ -75,32 +75,32 @@ func TestClientOptPort(t *testing.T) {
client.SetHTTPAddress(net.JoinHostPort(optHost, optPort)) client.SetHTTPAddress(net.JoinHostPort(optHost, optPort))
client.SetTLSAddress(net.JoinHostPort(optHost, optPort)) client.SetTLSAddress(net.JoinHostPort(optHost, optPort))
httpSolver, ok := client.solvers["http-01"].(*httpChallenge) httpSolver, ok := client.solvers[HTTP01].(*httpChallenge)
if !ok { if !ok {
t.Fatal("Expected http-01 solver to be httpChallenge type") t.Fatal("Expected http-01 solver to be httpChallenge type")
} }
if httpSolver.jws != client.jws { if httpSolver.jws != client.jws {
t.Error("Expected http-01 to have same jws as client") t.Error("Expected http-01 to have same jws as client")
} }
if httpSolver.port != optPort { if got := httpSolver.provider.(*httpChallengeServer).port; got != optPort {
t.Errorf("Expected http-01 to have port %s but was %s", optPort, httpSolver.port) t.Errorf("Expected http-01 to have port %s but was %s", optPort, got)
} }
if httpSolver.iface != optHost { if got := httpSolver.provider.(*httpChallengeServer).iface; got != optHost {
t.Errorf("Expected http-01 to have iface %s but was %s", optHost, httpSolver.iface) t.Errorf("Expected http-01 to have iface %s but was %s", optHost, got)
} }
httpsSolver, ok := client.solvers["tls-sni-01"].(*tlsSNIChallenge) httpsSolver, ok := client.solvers[TLSSNI01].(*tlsSNIChallenge)
if !ok { if !ok {
t.Fatal("Expected tls-sni-01 solver to be httpChallenge type") t.Fatal("Expected tls-sni-01 solver to be httpChallenge type")
} }
if httpsSolver.jws != client.jws { if httpsSolver.jws != client.jws {
t.Error("Expected tls-sni-01 to have same jws as client") t.Error("Expected tls-sni-01 to have same jws as client")
} }
if httpsSolver.port != optPort { if got := httpsSolver.provider.(*tlsSNIChallengeServer).port; got != optPort {
t.Errorf("Expected tls-sni-01 to have port %s but was %s", optPort, httpSolver.port) t.Errorf("Expected tls-sni-01 to have port %s but was %s", optPort, got)
} }
if httpsSolver.port != optPort { if got := httpsSolver.provider.(*tlsSNIChallengeServer).iface; got != optHost {
t.Errorf("Expected tls-sni-01 to have port %s but was %s", optHost, httpSolver.iface) t.Errorf("Expected tls-sni-01 to have port %s but was %s", optHost, got)
} }
// test setting different host // test setting different host
@ -108,11 +108,11 @@ func TestClientOptPort(t *testing.T) {
client.SetHTTPAddress(net.JoinHostPort(optHost, optPort)) client.SetHTTPAddress(net.JoinHostPort(optHost, optPort))
client.SetTLSAddress(net.JoinHostPort(optHost, optPort)) client.SetTLSAddress(net.JoinHostPort(optHost, optPort))
if httpSolver.iface != optHost { if got := httpSolver.provider.(*httpChallengeServer).iface; got != optHost {
t.Errorf("Expected http-01 to have iface %s but was %s", optHost, httpSolver.iface) t.Errorf("Expected http-01 to have iface %s but was %s", optHost, got)
} }
if httpsSolver.port != optPort { if got := httpsSolver.provider.(*tlsSNIChallengeServer).port; got != optPort {
t.Errorf("Expected tls-sni-01 to have port %s but was %s", optHost, httpSolver.iface) t.Errorf("Expected tls-sni-01 to have port %s but was %s", optPort, got)
} }
} }

View file

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
"strings" "strings"
"time" "time"
@ -19,37 +20,49 @@ var preCheckDNS preCheckDNSFunc = checkDNS
var preCheckDNSFallbackCount = 5 var preCheckDNSFallbackCount = 5
// DNSProvider represents a service for managing DNS records. // DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
type DNSProvider interface { func DNS01Record(domain, keyAuth string) (fqdn string, value string, ttl int) {
CreateTXTRecord(fqdn, value string, ttl int) error keyAuthShaBytes := sha256.Sum256([]byte(keyAuth))
RemoveTXTRecord(fqdn, value string, ttl int) error // base64URL encoding without padding
keyAuthSha := base64.URLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size])
value = strings.TrimRight(keyAuthSha, "=")
ttl = 120
fqdn = fmt.Sprintf("_acme-challenge.%s.", domain)
return
} }
// dnsChallenge implements the dns-01 challenge according to ACME 7.5 // dnsChallenge implements the dns-01 challenge according to ACME 7.5
type dnsChallenge struct { type dnsChallenge struct {
jws *jws jws *jws
provider DNSProvider provider ChallengeProvider
} }
func (s *dnsChallenge) Solve(chlng challenge, domain string) error { func (s *dnsChallenge) Solve(chlng challenge, domain string) error {
logf("[INFO] acme: Trying to solve DNS-01") logf("[INFO] acme: Trying to solve DNS-01")
if s.provider == nil {
return errors.New("No DNS Provider configured")
}
// Generate the Key Authorization for the challenge // Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, &s.jws.privKey.PublicKey) keyAuth, err := getKeyAuthorization(chlng.Token, &s.jws.privKey.PublicKey)
if err != nil { if err != nil {
return err return err
} }
keyAuthShaBytes := sha256.Sum256([]byte(keyAuth)) err = s.provider.Present(domain, chlng.Token, keyAuth)
// base64URL encoding without padding if err != nil {
keyAuthSha := base64.URLEncoding.EncodeToString(keyAuthShaBytes[:sha256.Size]) return fmt.Errorf("Error presenting token %s", err)
keyAuthSha = strings.TrimRight(keyAuthSha, "=")
fqdn := fmt.Sprintf("_acme-challenge.%s.", domain)
if err = s.provider.CreateTXTRecord(fqdn, keyAuthSha, 120); err != nil {
return err
} }
defer func() {
err := s.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil {
log.Printf("Error cleaning up %s %v ", domain, err)
}
}()
fqdn, _, _ := DNS01Record(domain, keyAuth)
preCheckDNS(domain, fqdn) preCheckDNS(domain, fqdn)
@ -94,10 +107,6 @@ Loop:
resp, err = http.Get(chlng.URI) resp, err = http.Get(chlng.URI)
} }
if err = s.provider.RemoveTXTRecord(fqdn, keyAuthSha, 120); err != nil {
logf("[WARN] acme: Failed to cleanup DNS record. -> %v ", err)
}
return nil return nil
} }

View file

@ -34,8 +34,9 @@ func NewDNSProviderCloudFlare(cloudflareEmail, cloudflareKey string) (*DNSProvid
return c, nil return c, nil
} }
// CreateTXTRecord creates a TXT record using the specified parameters // Present creates a TXT record to fulfil the dns-01 challenge
func (c *DNSProviderCloudFlare) CreateTXTRecord(fqdn, value string, ttl int) error { func (c *DNSProviderCloudFlare) Present(domain, token, keyAuth string) error {
fqdn, value, ttl := DNS01Record(domain, keyAuth)
zoneID, err := c.getHostedZoneID(fqdn) zoneID, err := c.getHostedZoneID(fqdn)
if err != nil { if err != nil {
return err return err
@ -50,8 +51,9 @@ func (c *DNSProviderCloudFlare) CreateTXTRecord(fqdn, value string, ttl int) err
return nil return nil
} }
// RemoveTXTRecord removes the TXT record matching the specified parameters // CleanUp removes the TXT record matching the specified parameters
func (c *DNSProviderCloudFlare) RemoveTXTRecord(fqdn, value string, ttl int) error { func (c *DNSProviderCloudFlare) CleanUp(domain, token, keyAuth string) error {
fqdn, _, _ := DNS01Record(domain, keyAuth)
records, err := c.findTxtRecords(fqdn) records, err := c.findTxtRecords(fqdn)
if err != nil { if err != nil {
return err return err
@ -60,10 +62,9 @@ func (c *DNSProviderCloudFlare) RemoveTXTRecord(fqdn, value string, ttl int) err
for _, rec := range records { for _, rec := range records {
err := c.client.Records.Delete(c.ctx, rec.ZoneID, rec.ID) err := c.client.Records.Delete(c.ctx, rec.ZoneID, rec.ID)
if err != nil { if err != nil {
return fmt.Errorf("CloudFlare API call has failed: %v", err) return err
} }
} }
return nil return nil
} }
@ -140,13 +141,15 @@ func unFqdn(name string) string {
// TTL must be between 120 and 86400 seconds // TTL must be between 120 and 86400 seconds
func sanitizeTTL(ttl int) int { func sanitizeTTL(ttl int) int {
if ttl < 120 { switch {
ttl = 120 case ttl < 120:
} else if ttl > 86400 { return 120
ttl = 86400 case ttl > 86400:
} return 86400
default:
return ttl return ttl
} }
}
func envAuth() (email, apiKey string) { func envAuth() (email, apiKey string) {
email = os.Getenv("CLOUDFLARE_EMAIL") email = os.Getenv("CLOUDFLARE_EMAIL")

View file

@ -1,7 +1,6 @@
package acme package acme
import ( import (
"fmt"
"os" "os"
"testing" "testing"
"time" "time"
@ -54,7 +53,7 @@ func TestNewDNSProviderCloudFlareMissingCredErr(t *testing.T) {
restoreCloudFlareEnv() restoreCloudFlareEnv()
} }
func TestCloudFlareCreateTXTRecord(t *testing.T) { func TestCloudFlarePresent(t *testing.T) {
if !cflareLiveTest { if !cflareLiveTest {
t.Skip("skipping live test") t.Skip("skipping live test")
} }
@ -62,12 +61,11 @@ func TestCloudFlareCreateTXTRecord(t *testing.T) {
provider, err := NewDNSProviderCloudFlare(cflareEmail, cflareAPIKey) provider, err := NewDNSProviderCloudFlare(cflareEmail, cflareAPIKey)
assert.NoError(t, err) assert.NoError(t, err)
fqdn := fmt.Sprintf("_acme-challenge.123.%s.", cflareDomain) err = provider.Present(cflareDomain, "", "123d==")
err = provider.CreateTXTRecord(fqdn, "123d==", 120)
assert.NoError(t, err) assert.NoError(t, err)
} }
func TestCloudFlareRemoveTXTRecord(t *testing.T) { func TestCloudFlareCleanUp(t *testing.T) {
if !cflareLiveTest { if !cflareLiveTest {
t.Skip("skipping live test") t.Skip("skipping live test")
} }
@ -77,7 +75,6 @@ func TestCloudFlareRemoveTXTRecord(t *testing.T) {
provider, err := NewDNSProviderCloudFlare(cflareEmail, cflareAPIKey) provider, err := NewDNSProviderCloudFlare(cflareEmail, cflareAPIKey)
assert.NoError(t, err) assert.NoError(t, err)
fqdn := fmt.Sprintf("_acme-challenge.123.%s.", cflareDomain) err = provider.CleanUp(cflareDomain, "", "123d==")
err = provider.RemoveTXTRecord(fqdn, "123d==", 120)
assert.NoError(t, err) assert.NoError(t, err)
} }

View file

@ -10,7 +10,7 @@ const (
dnsTemplate = "%s %d IN TXT \"%s\"" dnsTemplate = "%s %d IN TXT \"%s\""
) )
// DNSProviderManual is an implementation of the DNSProvider interface // DNSProviderManual is an implementation of the ChallengeProvider interface
type DNSProviderManual struct{} type DNSProviderManual struct{}
// NewDNSProviderManual returns a DNSProviderManual instance. // NewDNSProviderManual returns a DNSProviderManual instance.
@ -18,8 +18,9 @@ func NewDNSProviderManual() (*DNSProviderManual, error) {
return &DNSProviderManual{}, nil return &DNSProviderManual{}, nil
} }
// CreateTXTRecord prints instructions for manually creating the TXT record // Present prints instructions for manually creating the TXT record
func (*DNSProviderManual) CreateTXTRecord(fqdn, value string, ttl int) error { func (*DNSProviderManual) Present(domain, token, keyAuth string) error {
fqdn, value, ttl := DNS01Record(domain, keyAuth)
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value) dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value)
logf("[INFO] acme: Please create the following TXT record in your DNS zone:") logf("[INFO] acme: Please create the following TXT record in your DNS zone:")
logf("[INFO] acme: %s", dnsRecord) logf("[INFO] acme: %s", dnsRecord)
@ -29,9 +30,10 @@ func (*DNSProviderManual) CreateTXTRecord(fqdn, value string, ttl int) error {
return nil return nil
} }
// RemoveTXTRecord prints instructions for manually removing the TXT record // CleanUp prints instructions for manually removing the TXT record
func (*DNSProviderManual) RemoveTXTRecord(fqdn, value string, ttl int) error { func (*DNSProviderManual) CleanUp(domain, token, keyAuth string) error {
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, value) fqdn, _, ttl := DNS01Record(domain, keyAuth)
dnsRecord := fmt.Sprintf(dnsTemplate, fqdn, ttl, "...")
logf("[INFO] acme: You can now remove this TXT record from your DNS zone:") logf("[INFO] acme: You can now remove this TXT record from your DNS zone:")
logf("[INFO] acme: %s", dnsRecord) logf("[INFO] acme: %s", dnsRecord)
return nil return nil

View file

@ -6,13 +6,14 @@ import (
"time" "time"
) )
// DNSProviderRFC2136 is an implementation of the DNSProvider interface that // DNSProviderRFC2136 is an implementation of the ChallengeProvider interface that
// uses dynamic DNS updates (RFC 2136) to create TXT records on a nameserver. // uses dynamic DNS updates (RFC 2136) to create TXT records on a nameserver.
type DNSProviderRFC2136 struct { type DNSProviderRFC2136 struct {
nameserver string nameserver string
zone string zone string
tsigKey string tsigKey string
tsigSecret string tsigSecret string
records map[string]string
} }
// NewDNSProviderRFC2136 returns a new DNSProviderRFC2136 instance. // NewDNSProviderRFC2136 returns a new DNSProviderRFC2136 instance.
@ -23,6 +24,7 @@ func NewDNSProviderRFC2136(nameserver, zone, tsigKey, tsigSecret string) (*DNSPr
d := &DNSProviderRFC2136{ d := &DNSProviderRFC2136{
nameserver: nameserver, nameserver: nameserver,
zone: zone, zone: zone,
records: make(map[string]string),
} }
if len(tsigKey) > 0 && len(tsigSecret) > 0 { if len(tsigKey) > 0 && len(tsigSecret) > 0 {
d.tsigKey = tsigKey d.tsigKey = tsigKey
@ -32,13 +34,17 @@ func NewDNSProviderRFC2136(nameserver, zone, tsigKey, tsigSecret string) (*DNSPr
return d, nil return d, nil
} }
// CreateTXTRecord creates a TXT record using the specified parameters // Present creates a TXT record using the specified parameters
func (r *DNSProviderRFC2136) CreateTXTRecord(fqdn, value string, ttl int) error { func (r *DNSProviderRFC2136) Present(domain, token, keyAuth string) error {
fqdn, value, ttl := DNS01Record(domain, keyAuth)
r.records[fqdn] = value
return r.changeRecord("INSERT", fqdn, value, ttl) return r.changeRecord("INSERT", fqdn, value, ttl)
} }
// RemoveTXTRecord removes the TXT record matching the specified parameters // CleanUp removes the TXT record matching the specified parameters
func (r *DNSProviderRFC2136) RemoveTXTRecord(fqdn, value string, ttl int) error { func (r *DNSProviderRFC2136) CleanUp(domain, token, keyAuth string) error {
fqdn, _, ttl := DNS01Record(domain, keyAuth)
value := r.records[fqdn]
return r.changeRecord("REMOVE", fqdn, value, ttl) return r.changeRecord("REMOVE", fqdn, value, ttl)
} }

View file

@ -11,7 +11,9 @@ import (
) )
var ( var (
rfc2136TestValue = "so6ZGir4GaZqI11h9UccBB==" rfc2136TestDomain = "123456789.www.example.com"
rfc2136TestKeyAuth = "123d=="
rfc2136TestValue = "Now36o-3BmlB623-0c1qCIUmgWVVmDJb88KGl24pqpo"
rfc2136TestFqdn = "_acme-challenge.123456789.www.example.com." rfc2136TestFqdn = "_acme-challenge.123456789.www.example.com."
rfc2136TestZone = "example.com." rfc2136TestZone = "example.com."
rfc2136TestTTL = 120 rfc2136TestTTL = 120
@ -58,8 +60,8 @@ func TestRFC2136ServerSuccess(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err)
} }
if err := provider.CreateTXTRecord(rfc2136TestFqdn, rfc2136TestValue, rfc2136TestTTL); err != nil { if err := provider.Present(rfc2136TestDomain, "", rfc2136TestKeyAuth); err != nil {
t.Errorf("Expected CreateTXTRecord() to return no error but the error was -> %v", err) t.Errorf("Expected Present() to return no error but the error was -> %v", err)
} }
} }
@ -77,10 +79,10 @@ func TestRFC2136ServerError(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err)
} }
if err := provider.CreateTXTRecord(rfc2136TestFqdn, rfc2136TestValue, rfc2136TestTTL); err == nil { if err := provider.Present(rfc2136TestDomain, "", rfc2136TestKeyAuth); err == nil {
t.Errorf("Expected CreateTXTRecord() to return an error but it did not.") t.Errorf("Expected Present() to return an error but it did not.")
} else if !strings.Contains(err.Error(), "NOTZONE") { } else if !strings.Contains(err.Error(), "NOTZONE") {
t.Errorf("Expected CreateTXTRecord() to return an error with the 'NOTZONE' rcode string but it did not.") t.Errorf("Expected Present() to return an error with the 'NOTZONE' rcode string but it did not.")
} }
} }
@ -98,8 +100,8 @@ func TestRFC2136TsigClient(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err)
} }
if err := provider.CreateTXTRecord(rfc2136TestFqdn, rfc2136TestValue, rfc2136TestTTL); err != nil { if err := provider.Present(rfc2136TestDomain, "", rfc2136TestKeyAuth); err != nil {
t.Errorf("Expected CreateTXTRecord() to return no error but the error was -> %v", err) t.Errorf("Expected Present() to return no error but the error was -> %v", err)
} }
} }
@ -136,8 +138,9 @@ func TestRFC2136ValidUpdatePacket(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err) t.Fatalf("Expected NewDNSProviderRFC2136() to return no error but the error was -> %v", err)
} }
if err := provider.CreateTXTRecord(rfc2136TestFqdn, rfc2136TestValue, rfc2136TestTTL); err != nil {
t.Errorf("Expected CreateTXTRecord() to return no error but the error was -> %v", err) if err := provider.Present(rfc2136TestDomain, "", "1234d=="); err != nil {
t.Errorf("Expected Present() to return no error but the error was -> %v", err)
} }
rcvMsg := <-reqChan rcvMsg := <-reqChan

View file

@ -40,13 +40,15 @@ func NewDNSProviderRoute53(awsAccessKey, awsSecretKey, awsRegionName string) (*D
return &DNSProviderRoute53{client: client}, nil return &DNSProviderRoute53{client: client}, nil
} }
// CreateTXTRecord creates a TXT record using the specified parameters // Present creates a TXT record using the specified parameters
func (r *DNSProviderRoute53) CreateTXTRecord(fqdn, value string, ttl int) error { func (r *DNSProviderRoute53) Present(domain, token, keyAuth string) error {
fqdn, value, ttl := DNS01Record(domain, keyAuth)
return r.changeRecord("UPSERT", fqdn, value, ttl) return r.changeRecord("UPSERT", fqdn, value, ttl)
} }
// RemoveTXTRecord removes the TXT record matching the specified parameters // CleanUp removes the TXT record matching the specified parameters
func (r *DNSProviderRoute53) RemoveTXTRecord(fqdn, value string, ttl int) error { func (r *DNSProviderRoute53) CleanUp(domain, token, keyAuth string) error {
fqdn, value, ttl := DNS01Record(domain, keyAuth)
return r.changeRecord("DELETE", fqdn, value, ttl) return r.changeRecord("DELETE", fqdn, value, ttl)
} }

View file

@ -108,19 +108,22 @@ func TestNewDNSProviderRoute53InvalidRegionErr(t *testing.T) {
assert.EqualError(t, err, "Invalid AWS region name us-east-3") assert.EqualError(t, err, "Invalid AWS region name us-east-3")
} }
func TestRoute53CreateTXTRecord(t *testing.T) { func TestRoute53Present(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
testServer := makeRoute53TestServer() testServer := makeRoute53TestServer()
provider := makeRoute53Provider(testServer) provider := makeRoute53Provider(testServer)
testServer.ResponseMap(2, serverResponseMap) testServer.ResponseMap(2, serverResponseMap)
err := provider.CreateTXTRecord("_acme-challenge.123.example.com.", "123456d==", 120) domain := "example.com"
assert.NoError(err, "Expected CreateTXTRecord to return no error") keyAuth := "123456d=="
err := provider.Present(domain, "", keyAuth)
assert.NoError(err, "Expected Present to return no error")
httpReqs := testServer.WaitRequests(2) httpReqs := testServer.WaitRequests(2)
httpReq := httpReqs[1] httpReq := httpReqs[1]
assert.Equal("/2013-04-01/hostedzone/Z2K123214213123/rrset", httpReq.URL.Path, assert.Equal("/2013-04-01/hostedzone/Z2K123214213123/rrset", httpReq.URL.Path,
"Expected CreateTXTRecord to select the correct hostedzone") "Expected Present to select the correct hostedzone")
} }

View file

@ -2,73 +2,44 @@ package acme
import ( import (
"fmt" "fmt"
"net" "log"
"net/http"
"strings"
) )
type httpChallenge struct { type httpChallenge struct {
jws *jws jws *jws
validate validateFunc validate validateFunc
iface string provider ChallengeProvider
port string }
done chan bool
// HTTP01ChallengePath returns the URL path for the `http-01` challenge
func HTTP01ChallengePath(token string) string {
return "/.well-known/acme-challenge/" + token
} }
func (s *httpChallenge) Solve(chlng challenge, domain string) error { func (s *httpChallenge) Solve(chlng challenge, domain string) error {
logf("[INFO][%s] acme: Trying to solve HTTP-01", domain) logf("[INFO][%s] acme: Trying to solve HTTP-01", domain)
s.done = make(chan bool)
// Generate the Key Authorization for the challenge // Generate the Key Authorization for the challenge
keyAuth, err := getKeyAuthorization(chlng.Token, &s.jws.privKey.PublicKey) keyAuth, err := getKeyAuthorization(chlng.Token, &s.jws.privKey.PublicKey)
if err != nil { if err != nil {
return err return err
} }
// Allow for CLI port override if s.provider == nil {
port := "80" s.provider = &httpChallengeServer{}
if s.port != "" {
port = s.port
} }
iface := "" err = s.provider.Present(domain, chlng.Token, keyAuth)
if s.iface != "" {
iface = s.iface
}
listener, err := net.Listen("tcp", net.JoinHostPort(iface, port))
if err != nil { if err != nil {
return fmt.Errorf("Could not start HTTP server for challenge -> %v", err) return fmt.Errorf("Error presenting token %s", err)
} }
defer func() {
path := "/.well-known/acme-challenge/" + chlng.Token err := s.provider.CleanUp(domain, chlng.Token, keyAuth)
if err != nil {
go s.serve(listener, path, keyAuth, domain) log.Printf("Error cleaning up %s %v ", domain, err)
err = s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
listener.Close()
<-s.done
return err
} }
}()
func (s *httpChallenge) serve(listener net.Listener, path, keyAuth, domain string) { return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
// The handler validates the HOST header and request type.
// For validation it then writes the token the server returned with the challenge
mux := http.NewServeMux()
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.Host, domain) && r.Method == "GET" {
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte(keyAuth))
logf("[INFO][%s] Served key authentication", domain)
} else {
logf("[INFO] Received request for domain %s with method %s", r.Host, r.Method)
w.Write([]byte("TEST"))
}
})
http.Serve(listener, mux)
s.done <- true
} }

View file

@ -0,0 +1,63 @@
package acme
import (
"fmt"
"net"
"net/http"
"strings"
)
// httpChallengeServer implements ChallengeProvider for `http-01` challenge
type httpChallengeServer struct {
iface string
port string
done chan bool
listener net.Listener
}
// Present makes the token available at `HTTP01ChallengePath(token)`
func (s *httpChallengeServer) Present(domain, token, keyAuth string) error {
if s.port == "" {
s.port = "80"
}
var err error
s.listener, err = net.Listen("tcp", net.JoinHostPort(s.iface, s.port))
if err != nil {
return fmt.Errorf("Could not start HTTP server for challenge -> %v", err)
}
s.done = make(chan bool)
go s.serve(domain, token, keyAuth)
return nil
}
func (s *httpChallengeServer) CleanUp(domain, token, keyAuth string) error {
if s.listener == nil {
return nil
}
s.listener.Close()
<-s.done
return nil
}
func (s *httpChallengeServer) serve(domain, token, keyAuth string) {
path := HTTP01ChallengePath(token)
// The handler validates the HOST header and request type.
// For validation it then writes the token the server returned with the challenge
mux := http.NewServeMux()
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.Host, domain) && r.Method == "GET" {
w.Header().Add("Content-Type", "text/plain")
w.Write([]byte(keyAuth))
logf("[INFO][%s] Served key authentication", domain)
} else {
logf("[INFO] Received request for domain %s with method %s", r.Host, r.Method)
w.Write([]byte("TEST"))
}
})
http.Serve(s.listener, mux)
s.done <- true
}

View file

@ -10,7 +10,7 @@ import (
func TestHTTPChallenge(t *testing.T) { func TestHTTPChallenge(t *testing.T) {
privKey, _ := generatePrivateKey(rsakey, 512) privKey, _ := generatePrivateKey(rsakey, 512)
j := &jws{privKey: privKey.(*rsa.PrivateKey)} j := &jws{privKey: privKey.(*rsa.PrivateKey)}
clientChallenge := challenge{Type: "http-01", Token: "http1"} clientChallenge := challenge{Type: HTTP01, Token: "http1"}
mockValidate := func(_ *jws, _, _ string, chlng challenge) error { mockValidate := func(_ *jws, _, _ string, chlng challenge) error {
uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token
resp, err := httpGet(uri) resp, err := httpGet(uri)
@ -35,7 +35,7 @@ func TestHTTPChallenge(t *testing.T) {
return nil return nil
} }
solver := &httpChallenge{jws: j, validate: mockValidate, port: "23457"} solver := &httpChallenge{jws: j, validate: mockValidate, provider: &httpChallengeServer{port: "23457"}}
if err := solver.Solve(clientChallenge, "localhost:23457"); err != nil { if err := solver.Solve(clientChallenge, "localhost:23457"); err != nil {
t.Errorf("Solve error: got %v, want nil", err) t.Errorf("Solve error: got %v, want nil", err)
@ -45,8 +45,8 @@ func TestHTTPChallenge(t *testing.T) {
func TestHTTPChallengeInvalidPort(t *testing.T) { func TestHTTPChallengeInvalidPort(t *testing.T) {
privKey, _ := generatePrivateKey(rsakey, 128) privKey, _ := generatePrivateKey(rsakey, 128)
j := &jws{privKey: privKey.(*rsa.PrivateKey)} j := &jws{privKey: privKey.(*rsa.PrivateKey)}
clientChallenge := challenge{Type: "http-01", Token: "http2"} clientChallenge := challenge{Type: HTTP01, Token: "http2"}
solver := &httpChallenge{jws: j, validate: stubValidate, port: "123456"} solver := &httpChallenge{jws: j, validate: stubValidate, provider: &httpChallengeServer{port: "123456"}}
if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil {
t.Errorf("Solve error: got %v, want error", err) t.Errorf("Solve error: got %v, want error", err)

View file

@ -82,7 +82,7 @@ type validationRecord struct {
type challenge struct { type challenge struct {
Resource string `json:"resource,omitempty"` Resource string `json:"resource,omitempty"`
Type string `json:"type,omitempty"` Type Challenge `json:"type,omitempty"`
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
URI string `json:"uri,omitempty"` URI string `json:"uri,omitempty"`
Token string `json:"token,omitempty"` Token string `json:"token,omitempty"`

8
acme/provider.go Normal file
View file

@ -0,0 +1,8 @@
package acme
// ChallengeProvider presents the solution to a challenge available to be solved
// CleanUp will be called by the challenge if Present ends in a non-error state.
type ChallengeProvider interface {
Present(domain, token, keyAuth string) error
CleanUp(domain, token, keyAuth string) error
}

View file

@ -6,15 +6,13 @@ import (
"crypto/tls" "crypto/tls"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "log"
"net/http"
) )
type tlsSNIChallenge struct { type tlsSNIChallenge struct {
jws *jws jws *jws
validate validateFunc validate validateFunc
iface string provider ChallengeProvider
port string
} }
func (t *tlsSNIChallenge) Solve(chlng challenge, domain string) error { func (t *tlsSNIChallenge) Solve(chlng challenge, domain string) error {
@ -29,41 +27,25 @@ func (t *tlsSNIChallenge) Solve(chlng challenge, domain string) error {
return err return err
} }
cert, err := t.generateCertificate(keyAuth) if t.provider == nil {
t.provider = &tlsSNIChallengeServer{}
}
err = t.provider.Present(domain, chlng.Token, keyAuth)
if err != nil { if err != nil {
return err return fmt.Errorf("Error presenting token %s", err)
} }
defer func() {
// Allow for CLI port override err := t.provider.CleanUp(domain, chlng.Token, keyAuth)
port := "443"
if t.port != "" {
port = t.port
}
iface := ""
if t.iface != "" {
iface = t.iface
}
tlsConf := new(tls.Config)
tlsConf.Certificates = []tls.Certificate{cert}
listener, err := tls.Listen("tcp", net.JoinHostPort(iface, port), tlsConf)
if err != nil { if err != nil {
return fmt.Errorf("Could not start HTTPS server for challenge -> %v", err) log.Printf("Error cleaning up %s %v ", domain, err)
} }
defer listener.Close() }()
go http.Serve(listener, nil)
return t.validate(t.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) return t.validate(t.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth})
} }
func (t *tlsSNIChallenge) generateCertificate(keyAuth string) (tls.Certificate, error) { // TLSSNI01ChallengeCert returns a certificate for the `tls-sni-01` challenge
func TLSSNI01ChallengeCert(keyAuth string) (tls.Certificate, error) {
zBytes := sha256.Sum256([]byte(keyAuth))
z := hex.EncodeToString(zBytes[:sha256.Size])
// generate a new RSA key for the certificates // generate a new RSA key for the certificates
tempPrivKey, err := generatePrivateKey(rsakey, 2048) tempPrivKey, err := generatePrivateKey(rsakey, 2048)
if err != nil { if err != nil {
@ -72,6 +54,8 @@ func (t *tlsSNIChallenge) generateCertificate(keyAuth string) (tls.Certificate,
rsaPrivKey := tempPrivKey.(*rsa.PrivateKey) rsaPrivKey := tempPrivKey.(*rsa.PrivateKey)
rsaPrivPEM := pemEncode(rsaPrivKey) rsaPrivPEM := pemEncode(rsaPrivKey)
zBytes := sha256.Sum256([]byte(keyAuth))
z := hex.EncodeToString(zBytes[:sha256.Size])
domain := fmt.Sprintf("%s.%s.acme.invalid", z[:32], z[32:]) domain := fmt.Sprintf("%s.%s.acme.invalid", z[:32], z[32:])
tempCertPEM, err := generatePemCert(rsaPrivKey, domain) tempCertPEM, err := generatePemCert(rsaPrivKey, domain)
if err != nil { if err != nil {

View file

@ -0,0 +1,52 @@
package acme
import (
"crypto/tls"
"fmt"
"net"
"net/http"
)
// tlsSNIChallengeServer implements ChallengeProvider for `TLS-SNI-01` challenge
type tlsSNIChallengeServer struct {
iface string
port string
done chan bool
listener net.Listener
}
// Present makes the keyAuth available as a cert
func (s *tlsSNIChallengeServer) Present(domain, token, keyAuth string) error {
if s.port == "" {
s.port = "443"
}
cert, err := TLSSNI01ChallengeCert(keyAuth)
if err != nil {
return err
}
tlsConf := new(tls.Config)
tlsConf.Certificates = []tls.Certificate{cert}
s.listener, err = tls.Listen("tcp", net.JoinHostPort(s.iface, s.port), tlsConf)
if err != nil {
return fmt.Errorf("Could not start HTTPS server for challenge -> %v", err)
}
s.done = make(chan bool)
go func() {
http.Serve(s.listener, nil)
s.done <- true
}()
return nil
}
func (s *tlsSNIChallengeServer) CleanUp(domain, token, keyAuth string) error {
if s.listener == nil {
return nil
}
s.listener.Close()
<-s.done
return nil
}

View file

@ -13,7 +13,7 @@ import (
func TestTLSSNIChallenge(t *testing.T) { func TestTLSSNIChallenge(t *testing.T) {
privKey, _ := generatePrivateKey(rsakey, 512) privKey, _ := generatePrivateKey(rsakey, 512)
j := &jws{privKey: privKey.(*rsa.PrivateKey)} j := &jws{privKey: privKey.(*rsa.PrivateKey)}
clientChallenge := challenge{Type: "tls-sni-01", Token: "tlssni1"} clientChallenge := challenge{Type: TLSSNI01, Token: "tlssni1"}
mockValidate := func(_ *jws, _, _ string, chlng challenge) error { mockValidate := func(_ *jws, _, _ string, chlng challenge) error {
conn, err := tls.Dial("tcp", "localhost:23457", &tls.Config{ conn, err := tls.Dial("tcp", "localhost:23457", &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
@ -43,7 +43,7 @@ func TestTLSSNIChallenge(t *testing.T) {
return nil return nil
} }
solver := &tlsSNIChallenge{jws: j, validate: mockValidate, port: "23457"} solver := &tlsSNIChallenge{jws: j, validate: mockValidate, provider: &tlsSNIChallengeServer{port: "23457"}}
if err := solver.Solve(clientChallenge, "localhost:23457"); err != nil { if err := solver.Solve(clientChallenge, "localhost:23457"); err != nil {
t.Errorf("Solve error: got %v, want nil", err) t.Errorf("Solve error: got %v, want nil", err)
@ -53,8 +53,8 @@ func TestTLSSNIChallenge(t *testing.T) {
func TestTLSSNIChallengeInvalidPort(t *testing.T) { func TestTLSSNIChallengeInvalidPort(t *testing.T) {
privKey, _ := generatePrivateKey(rsakey, 128) privKey, _ := generatePrivateKey(rsakey, 128)
j := &jws{privKey: privKey.(*rsa.PrivateKey)} j := &jws{privKey: privKey.(*rsa.PrivateKey)}
clientChallenge := challenge{Type: "tls-sni-01", Token: "tlssni2"} clientChallenge := challenge{Type: TLSSNI01, Token: "tlssni2"}
solver := &tlsSNIChallenge{jws: j, validate: stubValidate, port: "123456"} solver := &tlsSNIChallenge{jws: j, validate: stubValidate, provider: &tlsSNIChallengeServer{port: "123456"}}
if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil {
t.Errorf("Solve error: got %v, want error", err) t.Errorf("Solve error: got %v, want error", err)

View file

@ -7,6 +7,7 @@ import (
"strings" "strings"
"github.com/codegangsta/cli" "github.com/codegangsta/cli"
"github.com/xenolf/lego/acme"
) )
// Configuration type from CLI and config files. // Configuration type from CLI and config files.
@ -24,8 +25,11 @@ func (c *Configuration) RsaBits() int {
return c.context.GlobalInt("rsa-key-size") return c.context.GlobalInt("rsa-key-size")
} }
func (c *Configuration) ExcludedSolvers() []string { func (c *Configuration) ExcludedSolvers() (cc []acme.Challenge) {
return c.context.GlobalStringSlice("exclude") for _, s := range c.context.GlobalStringSlice("exclude") {
cc = append(cc, acme.Challenge(s))
}
return
} }
// ServerPath returns the OS dependent path to the data for a specific CA // ServerPath returns the OS dependent path to the data for a specific CA