From 1c6f67f47a267af445b3636fdc0d45f853c967a9 Mon Sep 17 00:00:00 2001 From: Danek Duvall Date: Tue, 12 Feb 2019 08:36:44 -0800 Subject: [PATCH] Add a mechanism to wrap a PreCheckFunc (#783) --- challenge/dns01/dns_challenge.go | 2 +- challenge/dns01/dns_challenge_test.go | 46 +++++++++++++++------------ challenge/dns01/precheck.go | 29 ++++++++++++++--- 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/challenge/dns01/dns_challenge.go b/challenge/dns01/dns_challenge.go index 68c79589..ace01378 100644 --- a/challenge/dns01/dns_challenge.go +++ b/challenge/dns01/dns_challenge.go @@ -127,7 +127,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error { log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers) err = wait.For("propagation", timeout, interval, func() (bool, error) { - stop, errP := c.preCheck.call(fqdn, value) + stop, errP := c.preCheck.call(domain, fqdn, value) if !stop || errP != nil { log.Infof("[%s] acme: Waiting for DNS record propagation.", domain) } diff --git a/challenge/dns01/dns_challenge_test.go b/challenge/dns01/dns_challenge_test.go index 93bec0ce..172c0146 100644 --- a/challenge/dns01/dns_challenge_test.go +++ b/challenge/dns01/dns_challenge_test.go @@ -44,20 +44,20 @@ func TestChallenge_PreSolve(t *testing.T) { testCases := []struct { desc string validate ValidateFunc - preCheck PreCheckFunc + preCheck WrapPreCheckFunc provider challenge.Provider expectError bool }{ { desc: "success", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{}, }, { desc: "validate fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ present: nil, cleanUp: nil, @@ -66,7 +66,7 @@ func TestChallenge_PreSolve(t *testing.T) { { desc: "preCheck fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") }, provider: &providerTimeoutMock{ timeout: 2 * time.Second, interval: 500 * time.Millisecond, @@ -75,7 +75,7 @@ func TestChallenge_PreSolve(t *testing.T) { { desc: "present fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ present: errors.New("OOPS"), }, @@ -84,7 +84,7 @@ func TestChallenge_PreSolve(t *testing.T) { { desc: "cleanUp fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ cleanUp: errors.New("OOPS"), }, @@ -94,7 +94,7 @@ func TestChallenge_PreSolve(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck)) + chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck)) authz := acme.Authorization{ Identifier: acme.Identifier{ @@ -128,20 +128,20 @@ func TestChallenge_Solve(t *testing.T) { testCases := []struct { desc string validate ValidateFunc - preCheck PreCheckFunc + preCheck WrapPreCheckFunc provider challenge.Provider expectError bool }{ { desc: "success", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{}, }, { desc: "validate fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ present: nil, cleanUp: nil, @@ -151,7 +151,7 @@ func TestChallenge_Solve(t *testing.T) { { desc: "preCheck fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") }, provider: &providerTimeoutMock{ timeout: 2 * time.Second, interval: 500 * time.Millisecond, @@ -161,7 +161,7 @@ func TestChallenge_Solve(t *testing.T) { { desc: "present fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ present: errors.New("OOPS"), }, @@ -169,7 +169,7 @@ func TestChallenge_Solve(t *testing.T) { { desc: "cleanUp fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ cleanUp: errors.New("OOPS"), }, @@ -179,7 +179,11 @@ func TestChallenge_Solve(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck)) + var options []ChallengeOption + if test.preCheck != nil { + options = append(options, WrapPreCheck(test.preCheck)) + } + chlg := NewChallenge(core, test.validate, test.provider, options...) authz := acme.Authorization{ Identifier: acme.Identifier{ @@ -213,20 +217,20 @@ func TestChallenge_CleanUp(t *testing.T) { testCases := []struct { desc string validate ValidateFunc - preCheck PreCheckFunc + preCheck WrapPreCheckFunc provider challenge.Provider expectError bool }{ { desc: "success", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{}, }, { desc: "validate fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ present: nil, cleanUp: nil, @@ -235,7 +239,7 @@ func TestChallenge_CleanUp(t *testing.T) { { desc: "preCheck fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") }, provider: &providerTimeoutMock{ timeout: 2 * time.Second, interval: 500 * time.Millisecond, @@ -244,7 +248,7 @@ func TestChallenge_CleanUp(t *testing.T) { { desc: "present fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ present: errors.New("OOPS"), }, @@ -252,7 +256,7 @@ func TestChallenge_CleanUp(t *testing.T) { { desc: "cleanUp fail", validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil }, - preCheck: func(_, _ string) (bool, error) { return true, nil }, + preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil }, provider: &providerMock{ cleanUp: errors.New("OOPS"), }, @@ -263,7 +267,7 @@ func TestChallenge_CleanUp(t *testing.T) { for _, test := range testCases { t.Run(test.desc, func(t *testing.T) { - chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck)) + chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck)) authz := acme.Authorization{ Identifier: acme.Identifier{ diff --git a/challenge/dns01/precheck.go b/challenge/dns01/precheck.go index c3389110..00e09854 100644 --- a/challenge/dns01/precheck.go +++ b/challenge/dns01/precheck.go @@ -1,6 +1,7 @@ package dns01 import ( + "errors" "fmt" "net" "strings" @@ -11,11 +12,30 @@ import ( // PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready. type PreCheckFunc func(fqdn, value string) (bool, error) +// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after +// the main check, put it in a loop, etc. +type WrapPreCheckFunc func(domain, fqdn, value string, check PreCheckFunc) (bool, error) + +// WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready. +func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption { + return func(chlg *Challenge) error { + chlg.preCheck.checkFunc = wrap + return nil + } +} + +// AddPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready. +// Deprecated: use WrapPreCheck instead. func AddPreCheck(preCheck PreCheckFunc) ChallengeOption { // Prevent race condition check := preCheck return func(chlg *Challenge) error { - chlg.preCheck.checkFunc = check + chlg.preCheck.checkFunc = func(_, fqdn, value string, _ PreCheckFunc) (bool, error) { + if check == nil { + return false, errors.New("invalid preCheck: preCheck is nil") + } + return check(fqdn, value) + } return nil } } @@ -29,7 +49,7 @@ func DisableCompletePropagationRequirement() ChallengeOption { type preCheck struct { // checks DNS propagation before notifying ACME that the DNS challenge is ready. - checkFunc PreCheckFunc + checkFunc WrapPreCheckFunc // require the TXT record to be propagated to all authoritative name servers requireCompletePropagation bool } @@ -40,11 +60,12 @@ func newPreCheck() preCheck { } } -func (p preCheck) call(fqdn, value string) (bool, error) { +func (p preCheck) call(domain, fqdn, value string) (bool, error) { if p.checkFunc == nil { return p.checkDNSPropagation(fqdn, value) } - return p.checkFunc(fqdn, value) + + return p.checkFunc(domain, fqdn, value, p.checkDNSPropagation) } // checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.