Fix acme tests.

This commit is contained in:
Mariano Cano 2022-05-02 18:09:26 -07:00
parent ba499eeb2a
commit 2ab7dc6f9d
3 changed files with 171 additions and 131 deletions

View file

@ -29,6 +29,18 @@ import (
"github.com/smallstep/assert" "github.com/smallstep/assert"
) )
type mockClient struct {
get func(url string) (*http.Response, error)
lookupTxt func(name string) ([]string, error)
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
}
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return m.tlsDial(network, addr, config)
}
func Test_storeError(t *testing.T) { func Test_storeError(t *testing.T) {
type test struct { type test struct {
ch *Challenge ch *Challenge
@ -229,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) {
func TestChallenge_Validate(t *testing.T) { func TestChallenge_Validate(t *testing.T) {
type test struct { type test struct {
ch *Challenge ch *Challenge
vo *ValidateChallengeOptions vc Client
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
srv *httptest.Server srv *httptest.Server
@ -273,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -309,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -344,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -381,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -416,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) {
} }
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -466,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -493,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) {
defer tc.srv.Close() defer tc.srv.Close()
} }
if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -524,7 +537,7 @@ func (errReader) Close() error {
func TestHTTP01Validate(t *testing.T) { func TestHTTP01Validate(t *testing.T) {
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -541,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -575,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -608,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
StatusCode: http.StatusBadRequest, StatusCode: http.StatusBadRequest,
Body: errReader(0), Body: errReader(0),
@ -645,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
StatusCode: http.StatusBadRequest, StatusCode: http.StatusBadRequest,
Body: errReader(0), Body: errReader(0),
@ -681,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: errReader(0), Body: errReader(0),
}, nil }, nil
@ -704,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) {
jwk.Key = "foo" jwk.Key = "foo"
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -730,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -772,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -815,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
@ -857,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
@ -887,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -911,7 +925,7 @@ func TestDNS01Validate(t *testing.T) {
fulldomain := "*.zap.internal" fulldomain := "*.zap.internal"
domain := strings.TrimPrefix(fulldomain, "*.") domain := strings.TrimPrefix(fulldomain, "*.")
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -928,8 +942,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -963,8 +977,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1001,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo"}, nil return []string{"foo"}, nil
}, },
}, },
@ -1026,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", "bar"}, nil return []string{"foo", "bar"}, nil
}, },
}, },
@ -1068,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", "bar"}, nil return []string{"foo", "bar"}, nil
}, },
}, },
@ -1111,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", expected}, nil return []string{"foo", expected}, nil
}, },
}, },
@ -1156,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", expected}, nil return []string{"foo", expected}, nil
}, },
}, },
@ -1186,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -1206,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) {
} }
} }
type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error)
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
srv := httptest.NewUnstartedServer(http.NewServeMux()) srv := httptest.NewUnstartedServer(http.NewServeMux())
@ -1309,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) {
} }
} }
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -1321,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) {
ch := makeTLSCh() ch := makeTLSCh()
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1351,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) {
ch := makeTLSCh() ch := makeTLSCh()
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1384,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1413,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Client(&noopConn{}, config), nil return tls.Client(&noopConn{}, config), nil
}, },
}, },
@ -1443,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Client(&noopConn{}, config), nil return tls.Client(&noopConn{}, config), nil
}, },
}, },
@ -1479,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
}, },
}, },
@ -1516,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
}, },
}, },
@ -1562,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1605,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1649,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1692,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1736,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
srv: srv, srv: srv,
jwk: jwk, jwk: jwk,
@ -1758,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1797,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1841,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1884,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1924,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1963,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2008,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2054,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2100,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2144,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2189,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2226,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2253,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) {
defer tc.srv.Close() defer tc.srv.Close()
} }
if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:

View file

@ -206,6 +206,11 @@ func (l *linker) Middleware(next http.Handler) http.Handler {
// GetLink is a helper for GetLinkExplicit. // GetLink is a helper for GetLinkExplicit.
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var name string
if p, ok := ProvisionerFromContext(ctx); ok {
name = p.GetName()
}
var u url.URL var u url.URL
if baseURL := baseURLFromContext(ctx); baseURL != nil { if baseURL := baseURLFromContext(ctx); baseURL != nil {
u = *baseURL u = *baseURL
@ -217,8 +222,7 @@ func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) st
u.Host = l.dns u.Host = l.dns
} }
p := MustProvisionerFromContext(ctx) u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...)
return u.String() return u.String()
} }

View file

@ -5,16 +5,34 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"testing" "testing"
"time"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
) )
func TestLinker_GetUnescapedPathSuffix(t *testing.T) { func mockProvisioner(t *testing.T) Provisioner {
dns := "ca.smallstep.com" t.Helper()
prefix := "acme" var defaultDisableRenewal = false
linker := NewLinker(dns, prefix)
getPath := linker.GetUnescapedPathSuffix // Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-<test>provisioner.com",
}
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
}}); err != nil {
fmt.Printf("%v", err)
}
return p
}
func TestGetUnescapedPathSuffix(t *testing.T) {
getPath := GetUnescapedPathSuffix
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce") assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory") assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
@ -31,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
} }
func TestLinker_DNS(t *testing.T) { func TestLinker_DNS(t *testing.T) {
prov := newProv() prov := mockProvisioner(t)
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := NewProvisionerContext(context.Background(), prov)
type test struct { type test struct {
name string name string
dns string dns string
@ -116,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) {
linker := NewLinker(dns, prefix) linker := NewLinker(dns, prefix)
id := "1234" id := "1234"
prov := newProv() prov := mockProvisioner(t)
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
// No provisioner and no BaseURL from request // No provisioner and no BaseURL from request
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
// Provisioner: yes, BaseURL: no // Provisioner: yes, BaseURL: no
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
// Provisioner: no, BaseURL: yes // Provisioner: no, BaseURL: yes
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
@ -162,10 +180,10 @@ func TestLinker_GetLink(t *testing.T) {
func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
oid := "orderID" oid := "orderID"
certID := "certID" certID := "certID"
@ -227,10 +245,10 @@ func TestLinker_LinkOrder(t *testing.T) {
func TestLinker_LinkAccount(t *testing.T) { func TestLinker_LinkAccount(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
accID := "accountID" accID := "accountID"
linkerPrefix := "acme" linkerPrefix := "acme"
@ -259,10 +277,10 @@ func TestLinker_LinkAccount(t *testing.T) {
func TestLinker_LinkChallenge(t *testing.T) { func TestLinker_LinkChallenge(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
chID := "chID" chID := "chID"
azID := "azID" azID := "azID"
@ -292,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
func TestLinker_LinkAuthorization(t *testing.T) { func TestLinker_LinkAuthorization(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
chID0 := "chID-0" chID0 := "chID-0"
chID1 := "chID-1" chID1 := "chID-1"
@ -334,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
func TestLinker_LinkOrdersByAccountID(t *testing.T) { func TestLinker_LinkOrdersByAccountID(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)