Fix acme tests.
This commit is contained in:
parent
ba499eeb2a
commit
2ab7dc6f9d
3 changed files with 171 additions and 131 deletions
|
@ -29,6 +29,18 @@ import (
|
|||
"github.com/smallstep/assert"
|
||||
)
|
||||
|
||||
type mockClient struct {
|
||||
get func(url string) (*http.Response, error)
|
||||
lookupTxt func(name string) ([]string, error)
|
||||
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
|
||||
}
|
||||
|
||||
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
|
||||
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
|
||||
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return m.tlsDial(network, addr, config)
|
||||
}
|
||||
|
||||
func Test_storeError(t *testing.T) {
|
||||
type test struct {
|
||||
ch *Challenge
|
||||
|
@ -229,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) {
|
|||
func TestChallenge_Validate(t *testing.T) {
|
||||
type test struct {
|
||||
ch *Challenge
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
srv *httptest.Server
|
||||
|
@ -273,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -309,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -344,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -381,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -416,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
}
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -466,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -493,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) {
|
|||
defer tc.srv.Close()
|
||||
}
|
||||
|
||||
if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -524,7 +537,7 @@ func (errReader) Close() error {
|
|||
|
||||
func TestHTTP01Validate(t *testing.T) {
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -541,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -575,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -608,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: errReader(0),
|
||||
|
@ -645,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Body: errReader(0),
|
||||
|
@ -681,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: errReader(0),
|
||||
}, nil
|
||||
|
@ -704,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
jwk.Key = "foo"
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -730,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -772,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString("foo")),
|
||||
}, nil
|
||||
|
@ -815,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
||||
}, nil
|
||||
|
@ -857,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
HTTPGet: func(url string) (*http.Response, error) {
|
||||
vc: &mockClient{
|
||||
get: func(url string) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
|
||||
}, nil
|
||||
|
@ -887,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -911,7 +925,7 @@ func TestDNS01Validate(t *testing.T) {
|
|||
fulldomain := "*.zap.internal"
|
||||
domain := strings.TrimPrefix(fulldomain, "*.")
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -928,8 +942,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -963,8 +977,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1001,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1026,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", "bar"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1068,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", "bar"}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1111,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", expected}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1156,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
LookupTxt: func(url string) ([]string, error) {
|
||||
vc: &mockClient{
|
||||
lookupTxt: func(url string) ([]string, error) {
|
||||
return []string{"foo", expected}, nil
|
||||
},
|
||||
},
|
||||
|
@ -1186,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
for name, run := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := run(t)
|
||||
if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
@ -1206,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error)
|
||||
|
||||
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
|
||||
srv := httptest.NewUnstartedServer(http.NewServeMux())
|
||||
|
||||
|
@ -1309,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
}
|
||||
}
|
||||
type test struct {
|
||||
vo *ValidateChallengeOptions
|
||||
vc Client
|
||||
ch *Challenge
|
||||
jwk *jose.JSONWebKey
|
||||
db DB
|
||||
|
@ -1321,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
ch := makeTLSCh()
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1351,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
ch := makeTLSCh()
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return nil, errors.New("force")
|
||||
},
|
||||
},
|
||||
|
@ -1384,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1413,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.Client(&noopConn{}, config), nil
|
||||
},
|
||||
},
|
||||
|
@ -1443,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.Client(&noopConn{}, config), nil
|
||||
},
|
||||
},
|
||||
|
@ -1479,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
||||
},
|
||||
},
|
||||
|
@ -1516,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
vc: &mockClient{
|
||||
tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
|
||||
},
|
||||
},
|
||||
|
@ -1562,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1605,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1649,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1692,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1736,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
srv: srv,
|
||||
jwk: jwk,
|
||||
|
@ -1758,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1797,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1841,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1884,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1924,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -1963,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2008,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2054,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2100,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2144,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2189,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2226,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
|
||||
return test{
|
||||
ch: ch,
|
||||
vo: &ValidateChallengeOptions{
|
||||
TLSDial: tlsDial,
|
||||
vc: &mockClient{
|
||||
tlsDial: tlsDial,
|
||||
},
|
||||
db: &MockDB{
|
||||
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
|
||||
|
@ -2253,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) {
|
|||
defer tc.srv.Close()
|
||||
}
|
||||
|
||||
if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil {
|
||||
ctx := NewClientContext(context.Background(), tc.vc)
|
||||
if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
switch k := err.(type) {
|
||||
case *Error:
|
||||
|
|
|
@ -206,6 +206,11 @@ func (l *linker) Middleware(next http.Handler) http.Handler {
|
|||
|
||||
// GetLink is a helper for GetLinkExplicit.
|
||||
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
|
||||
var name string
|
||||
if p, ok := ProvisionerFromContext(ctx); ok {
|
||||
name = p.GetName()
|
||||
}
|
||||
|
||||
var u url.URL
|
||||
if baseURL := baseURLFromContext(ctx); baseURL != nil {
|
||||
u = *baseURL
|
||||
|
@ -217,8 +222,7 @@ func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) st
|
|||
u.Host = l.dns
|
||||
}
|
||||
|
||||
p := MustProvisionerFromContext(ctx)
|
||||
u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...)
|
||||
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
|
|
|
@ -5,16 +5,34 @@ import (
|
|||
"fmt"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
||||
dns := "ca.smallstep.com"
|
||||
prefix := "acme"
|
||||
linker := NewLinker(dns, prefix)
|
||||
func mockProvisioner(t *testing.T) Provisioner {
|
||||
t.Helper()
|
||||
var defaultDisableRenewal = false
|
||||
|
||||
getPath := linker.GetUnescapedPathSuffix
|
||||
// Initialize provisioners
|
||||
p := &provisioner.ACME{
|
||||
Type: "ACME",
|
||||
Name: "test@acme-<test>provisioner.com",
|
||||
}
|
||||
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
|
||||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}}); err != nil {
|
||||
fmt.Printf("%v", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func TestGetUnescapedPathSuffix(t *testing.T) {
|
||||
getPath := GetUnescapedPathSuffix
|
||||
|
||||
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
|
||||
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
|
||||
|
@ -31,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLinker_DNS(t *testing.T) {
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
type test struct {
|
||||
name string
|
||||
dns string
|
||||
|
@ -116,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) {
|
|||
linker := NewLinker(dns, prefix)
|
||||
id := "1234"
|
||||
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
escProvName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
// No provisioner and no BaseURL from request
|
||||
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
|
||||
// Provisioner: yes, BaseURL: no
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
|
||||
|
||||
// Provisioner: no, BaseURL: yes
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
|
||||
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
|
||||
|
@ -162,10 +180,10 @@ func TestLinker_GetLink(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkOrder(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
oid := "orderID"
|
||||
certID := "certID"
|
||||
|
@ -227,10 +245,10 @@ func TestLinker_LinkOrder(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkAccount(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
accID := "accountID"
|
||||
linkerPrefix := "acme"
|
||||
|
@ -259,10 +277,10 @@ func TestLinker_LinkAccount(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkChallenge(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
chID := "chID"
|
||||
azID := "azID"
|
||||
|
@ -292,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkAuthorization(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
chID0 := "chID-0"
|
||||
chID1 := "chID-1"
|
||||
|
@ -334,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
|
|||
|
||||
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
prov := newProv()
|
||||
prov := mockProvisioner(t)
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, provisionerContextKey, prov)
|
||||
ctx := NewProvisionerContext(context.Background(), prov)
|
||||
ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
|
||||
|
||||
linkerPrefix := "acme"
|
||||
l := NewLinker("dns", linkerPrefix)
|
||||
|
|
Loading…
Reference in a new issue