1
0
Fork 0
forked from TrueCloudLab/lego

Add a mechanism to wrap a PreCheckFunc ()

This commit is contained in:
Danek Duvall 2019-02-12 08:36:44 -08:00 committed by Ludovic Fernandez
parent 19303d3ac6
commit 1c6f67f47a
3 changed files with 51 additions and 26 deletions

View file

@ -127,7 +127,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers)
err = wait.For("propagation", timeout, interval, func() (bool, error) {
stop, errP := c.preCheck.call(fqdn, value)
stop, errP := c.preCheck.call(domain, fqdn, value)
if !stop || errP != nil {
log.Infof("[%s] acme: Waiting for DNS record propagation.", domain)
}

View file

@ -44,20 +44,20 @@ func TestChallenge_PreSolve(t *testing.T) {
testCases := []struct {
desc string
validate ValidateFunc
preCheck PreCheckFunc
preCheck WrapPreCheckFunc
provider challenge.Provider
expectError bool
}{
{
desc: "success",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{},
},
{
desc: "validate fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: nil,
cleanUp: nil,
@ -66,7 +66,7 @@ func TestChallenge_PreSolve(t *testing.T) {
{
desc: "preCheck fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
provider: &providerTimeoutMock{
timeout: 2 * time.Second,
interval: 500 * time.Millisecond,
@ -75,7 +75,7 @@ func TestChallenge_PreSolve(t *testing.T) {
{
desc: "present fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: errors.New("OOPS"),
},
@ -84,7 +84,7 @@ func TestChallenge_PreSolve(t *testing.T) {
{
desc: "cleanUp fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
cleanUp: errors.New("OOPS"),
},
@ -94,7 +94,7 @@ func TestChallenge_PreSolve(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck))
chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
authz := acme.Authorization{
Identifier: acme.Identifier{
@ -128,20 +128,20 @@ func TestChallenge_Solve(t *testing.T) {
testCases := []struct {
desc string
validate ValidateFunc
preCheck PreCheckFunc
preCheck WrapPreCheckFunc
provider challenge.Provider
expectError bool
}{
{
desc: "success",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{},
},
{
desc: "validate fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: nil,
cleanUp: nil,
@ -151,7 +151,7 @@ func TestChallenge_Solve(t *testing.T) {
{
desc: "preCheck fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
provider: &providerTimeoutMock{
timeout: 2 * time.Second,
interval: 500 * time.Millisecond,
@ -161,7 +161,7 @@ func TestChallenge_Solve(t *testing.T) {
{
desc: "present fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: errors.New("OOPS"),
},
@ -169,7 +169,7 @@ func TestChallenge_Solve(t *testing.T) {
{
desc: "cleanUp fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
cleanUp: errors.New("OOPS"),
},
@ -179,7 +179,11 @@ func TestChallenge_Solve(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck))
var options []ChallengeOption
if test.preCheck != nil {
options = append(options, WrapPreCheck(test.preCheck))
}
chlg := NewChallenge(core, test.validate, test.provider, options...)
authz := acme.Authorization{
Identifier: acme.Identifier{
@ -213,20 +217,20 @@ func TestChallenge_CleanUp(t *testing.T) {
testCases := []struct {
desc string
validate ValidateFunc
preCheck PreCheckFunc
preCheck WrapPreCheckFunc
provider challenge.Provider
expectError bool
}{
{
desc: "success",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{},
},
{
desc: "validate fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: nil,
cleanUp: nil,
@ -235,7 +239,7 @@ func TestChallenge_CleanUp(t *testing.T) {
{
desc: "preCheck fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
provider: &providerTimeoutMock{
timeout: 2 * time.Second,
interval: 500 * time.Millisecond,
@ -244,7 +248,7 @@ func TestChallenge_CleanUp(t *testing.T) {
{
desc: "present fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
present: errors.New("OOPS"),
},
@ -252,7 +256,7 @@ func TestChallenge_CleanUp(t *testing.T) {
{
desc: "cleanUp fail",
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
preCheck: func(_, _ string) (bool, error) { return true, nil },
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
provider: &providerMock{
cleanUp: errors.New("OOPS"),
},
@ -263,7 +267,7 @@ func TestChallenge_CleanUp(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck))
chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
authz := acme.Authorization{
Identifier: acme.Identifier{

View file

@ -1,6 +1,7 @@
package dns01
import (
"errors"
"fmt"
"net"
"strings"
@ -11,11 +12,30 @@ import (
// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready.
type PreCheckFunc func(fqdn, value string) (bool, error)
// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after
// the main check, put it in a loop, etc.
type WrapPreCheckFunc func(domain, fqdn, value string, check PreCheckFunc) (bool, error)
// WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
return func(chlg *Challenge) error {
chlg.preCheck.checkFunc = wrap
return nil
}
}
// AddPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
// Deprecated: use WrapPreCheck instead.
func AddPreCheck(preCheck PreCheckFunc) ChallengeOption {
// Prevent race condition
check := preCheck
return func(chlg *Challenge) error {
chlg.preCheck.checkFunc = check
chlg.preCheck.checkFunc = func(_, fqdn, value string, _ PreCheckFunc) (bool, error) {
if check == nil {
return false, errors.New("invalid preCheck: preCheck is nil")
}
return check(fqdn, value)
}
return nil
}
}
@ -29,7 +49,7 @@ func DisableCompletePropagationRequirement() ChallengeOption {
type preCheck struct {
// checks DNS propagation before notifying ACME that the DNS challenge is ready.
checkFunc PreCheckFunc
checkFunc WrapPreCheckFunc
// require the TXT record to be propagated to all authoritative name servers
requireCompletePropagation bool
}
@ -40,11 +60,12 @@ func newPreCheck() preCheck {
}
}
func (p preCheck) call(fqdn, value string) (bool, error) {
func (p preCheck) call(domain, fqdn, value string) (bool, error) {
if p.checkFunc == nil {
return p.checkDNSPropagation(fqdn, value)
}
return p.checkFunc(fqdn, value)
return p.checkFunc(domain, fqdn, value, p.checkDNSPropagation)
}
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.