forked from TrueCloudLab/lego
Add a mechanism to wrap a PreCheckFunc (#783)
This commit is contained in:
parent
19303d3ac6
commit
1c6f67f47a
3 changed files with 51 additions and 26 deletions
challenge/dns01
|
@ -127,7 +127,7 @@ func (c *Challenge) Solve(authz acme.Authorization) error {
|
|||
log.Infof("[%s] acme: Checking DNS record propagation using %+v", domain, recursiveNameservers)
|
||||
|
||||
err = wait.For("propagation", timeout, interval, func() (bool, error) {
|
||||
stop, errP := c.preCheck.call(fqdn, value)
|
||||
stop, errP := c.preCheck.call(domain, fqdn, value)
|
||||
if !stop || errP != nil {
|
||||
log.Infof("[%s] acme: Waiting for DNS record propagation.", domain)
|
||||
}
|
||||
|
|
|
@ -44,20 +44,20 @@ func TestChallenge_PreSolve(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
validate ValidateFunc
|
||||
preCheck PreCheckFunc
|
||||
preCheck WrapPreCheckFunc
|
||||
provider challenge.Provider
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
desc: "success",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{},
|
||||
},
|
||||
{
|
||||
desc: "validate fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: nil,
|
||||
cleanUp: nil,
|
||||
|
@ -66,7 +66,7 @@ func TestChallenge_PreSolve(t *testing.T) {
|
|||
{
|
||||
desc: "preCheck fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
|
||||
provider: &providerTimeoutMock{
|
||||
timeout: 2 * time.Second,
|
||||
interval: 500 * time.Millisecond,
|
||||
|
@ -75,7 +75,7 @@ func TestChallenge_PreSolve(t *testing.T) {
|
|||
{
|
||||
desc: "present fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: errors.New("OOPS"),
|
||||
},
|
||||
|
@ -84,7 +84,7 @@ func TestChallenge_PreSolve(t *testing.T) {
|
|||
{
|
||||
desc: "cleanUp fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
cleanUp: errors.New("OOPS"),
|
||||
},
|
||||
|
@ -94,7 +94,7 @@ func TestChallenge_PreSolve(t *testing.T) {
|
|||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck))
|
||||
chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
|
||||
|
||||
authz := acme.Authorization{
|
||||
Identifier: acme.Identifier{
|
||||
|
@ -128,20 +128,20 @@ func TestChallenge_Solve(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
validate ValidateFunc
|
||||
preCheck PreCheckFunc
|
||||
preCheck WrapPreCheckFunc
|
||||
provider challenge.Provider
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
desc: "success",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{},
|
||||
},
|
||||
{
|
||||
desc: "validate fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: nil,
|
||||
cleanUp: nil,
|
||||
|
@ -151,7 +151,7 @@ func TestChallenge_Solve(t *testing.T) {
|
|||
{
|
||||
desc: "preCheck fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
|
||||
provider: &providerTimeoutMock{
|
||||
timeout: 2 * time.Second,
|
||||
interval: 500 * time.Millisecond,
|
||||
|
@ -161,7 +161,7 @@ func TestChallenge_Solve(t *testing.T) {
|
|||
{
|
||||
desc: "present fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: errors.New("OOPS"),
|
||||
},
|
||||
|
@ -169,7 +169,7 @@ func TestChallenge_Solve(t *testing.T) {
|
|||
{
|
||||
desc: "cleanUp fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
cleanUp: errors.New("OOPS"),
|
||||
},
|
||||
|
@ -179,7 +179,11 @@ func TestChallenge_Solve(t *testing.T) {
|
|||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck))
|
||||
var options []ChallengeOption
|
||||
if test.preCheck != nil {
|
||||
options = append(options, WrapPreCheck(test.preCheck))
|
||||
}
|
||||
chlg := NewChallenge(core, test.validate, test.provider, options...)
|
||||
|
||||
authz := acme.Authorization{
|
||||
Identifier: acme.Identifier{
|
||||
|
@ -213,20 +217,20 @@ func TestChallenge_CleanUp(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
validate ValidateFunc
|
||||
preCheck PreCheckFunc
|
||||
preCheck WrapPreCheckFunc
|
||||
provider challenge.Provider
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
desc: "success",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{},
|
||||
},
|
||||
{
|
||||
desc: "validate fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return errors.New("OOPS") },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: nil,
|
||||
cleanUp: nil,
|
||||
|
@ -235,7 +239,7 @@ func TestChallenge_CleanUp(t *testing.T) {
|
|||
{
|
||||
desc: "preCheck fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return false, errors.New("OOPS") },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return false, errors.New("OOPS") },
|
||||
provider: &providerTimeoutMock{
|
||||
timeout: 2 * time.Second,
|
||||
interval: 500 * time.Millisecond,
|
||||
|
@ -244,7 +248,7 @@ func TestChallenge_CleanUp(t *testing.T) {
|
|||
{
|
||||
desc: "present fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
present: errors.New("OOPS"),
|
||||
},
|
||||
|
@ -252,7 +256,7 @@ func TestChallenge_CleanUp(t *testing.T) {
|
|||
{
|
||||
desc: "cleanUp fail",
|
||||
validate: func(_ *api.Core, _ string, _ acme.Challenge) error { return nil },
|
||||
preCheck: func(_, _ string) (bool, error) { return true, nil },
|
||||
preCheck: func(_, _, _ string, _ PreCheckFunc) (bool, error) { return true, nil },
|
||||
provider: &providerMock{
|
||||
cleanUp: errors.New("OOPS"),
|
||||
},
|
||||
|
@ -263,7 +267,7 @@ func TestChallenge_CleanUp(t *testing.T) {
|
|||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
chlg := NewChallenge(core, test.validate, test.provider, AddPreCheck(test.preCheck))
|
||||
chlg := NewChallenge(core, test.validate, test.provider, WrapPreCheck(test.preCheck))
|
||||
|
||||
authz := acme.Authorization{
|
||||
Identifier: acme.Identifier{
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package dns01
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
@ -11,11 +12,30 @@ import (
|
|||
// PreCheckFunc checks DNS propagation before notifying ACME that the DNS challenge is ready.
|
||||
type PreCheckFunc func(fqdn, value string) (bool, error)
|
||||
|
||||
// WrapPreCheckFunc wraps a PreCheckFunc in order to do extra operations before or after
|
||||
// the main check, put it in a loop, etc.
|
||||
type WrapPreCheckFunc func(domain, fqdn, value string, check PreCheckFunc) (bool, error)
|
||||
|
||||
// WrapPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
|
||||
func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
|
||||
return func(chlg *Challenge) error {
|
||||
chlg.preCheck.checkFunc = wrap
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddPreCheck Allow to define checks before notifying ACME that the DNS challenge is ready.
|
||||
// Deprecated: use WrapPreCheck instead.
|
||||
func AddPreCheck(preCheck PreCheckFunc) ChallengeOption {
|
||||
// Prevent race condition
|
||||
check := preCheck
|
||||
return func(chlg *Challenge) error {
|
||||
chlg.preCheck.checkFunc = check
|
||||
chlg.preCheck.checkFunc = func(_, fqdn, value string, _ PreCheckFunc) (bool, error) {
|
||||
if check == nil {
|
||||
return false, errors.New("invalid preCheck: preCheck is nil")
|
||||
}
|
||||
return check(fqdn, value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -29,7 +49,7 @@ func DisableCompletePropagationRequirement() ChallengeOption {
|
|||
|
||||
type preCheck struct {
|
||||
// checks DNS propagation before notifying ACME that the DNS challenge is ready.
|
||||
checkFunc PreCheckFunc
|
||||
checkFunc WrapPreCheckFunc
|
||||
// require the TXT record to be propagated to all authoritative name servers
|
||||
requireCompletePropagation bool
|
||||
}
|
||||
|
@ -40,11 +60,12 @@ func newPreCheck() preCheck {
|
|||
}
|
||||
}
|
||||
|
||||
func (p preCheck) call(fqdn, value string) (bool, error) {
|
||||
func (p preCheck) call(domain, fqdn, value string) (bool, error) {
|
||||
if p.checkFunc == nil {
|
||||
return p.checkDNSPropagation(fqdn, value)
|
||||
}
|
||||
return p.checkFunc(fqdn, value)
|
||||
|
||||
return p.checkFunc(domain, fqdn, value, p.checkDNSPropagation)
|
||||
}
|
||||
|
||||
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
|
||||
|
|
Loading…
Reference in a new issue