diff --git a/acme/challenge_test.go b/acme/challenge_test.go index 11b30961..29bd5a71 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -1,643 +1,27 @@ package acme -/* -var testOps = ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "", // will get set correctly depending on the "new.." method. - Value: "zap.internal", - }, -} +import ( + "bytes" + "context" + "crypto" + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" -func newDNSCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newDNS01Challenge(mockdb, testOps) -} - -func newTLSALPNCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newTLSALPN01Challenge(mockdb, testOps) -} - -func newHTTPCh() (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newHTTP01Challenge(mockdb, testOps) -} - -func newHTTPChWithServer(host string) (Challenge, error) { - mockdb := &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - } - return newHTTP01Challenge(mockdb, ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "", // will get set correctly depending on the "new.." method. - Value: host, - }, - }) -} - -func TestNewHTTP01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "http", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newHTTP01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "http-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} - -func TestNewTLSALPN01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "http", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newTLSALPN01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "tls-alpn-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} - -func TestNewDNS01Challenge(t *testing.T) { - ops := ChallengeOptions{ - AccountID: "accID", - AuthzID: "authzID", - Identifier: Identifier{ - Type: "dns", - Value: "zap.internal", - }, - } - type test struct { - ops ChallengeOptions - db nosql.DB - err *Error - } - tests := map[string]test{ - "fail/store-error": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - }, - "ok": { - ops: ops, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), true, nil - }, - }, - }, - } - for name, tc := range tests { - t.Run(name, func(t *testing.T) { - ch, err := newDNS01Challenge(tc.db, tc.ops) - if err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, ch.getAccountID(), ops.AccountID) - assert.Equals(t, ch.getAuthzID(), ops.AuthzID) - assert.Equals(t, ch.getType(), "dns-01") - assert.Equals(t, ch.getValue(), "zap.internal") - assert.Equals(t, ch.getStatus(), StatusPending) - - assert.True(t, ch.getValidated().IsZero()) - assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) - - assert.True(t, ch.getID() != "") - assert.True(t, ch.getToken() != "") - } - } - }) - } -} - -func TestChallengeToACME(t *testing.T) { - dir := newDirectory("ca.smallstep.com", "acme") - - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Validated = clock.Now() - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - - prov := newProv() - provName := url.PathEscape(prov.GetName()) - baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) - ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) - tests := map[string]challenge{ - "dns": dnsCh, - "http": httpCh, - "tls-alpn": tlsALPNCh, - } - for name, ch := range tests { - t.Run(name, func(t *testing.T) { - ach, err := ch.toACME(ctx, nil, dir) - assert.FatalError(t, err) - - assert.Equals(t, ach.Type, ch.getType()) - assert.Equals(t, ach.Status, ch.getStatus()) - assert.Equals(t, ach.Token, ch.getToken()) - assert.Equals(t, ach.URL, - fmt.Sprintf("%s/acme/%s/challenge/%s", - baseURL.String(), provName, ch.getID())) - assert.Equals(t, ach.ID, ch.getID()) - assert.Equals(t, ach.AuthzID, ch.getAuthzID()) - - if ach.Type == "http-01" { - v, err := time.Parse(time.RFC3339, ach.Validated) - assert.FatalError(t, err) - assert.Equals(t, v.String(), _httpCh.baseChallenge.Validated.String()) - } else { - assert.Equals(t, ach.Validated, "") - } - }) - } -} - -func TestChallengeSave(t *testing.T) { - type test struct { - ch challenge - old challenge - db nosql.DB - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/old-nil/swap-error": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), - } - }, - "fail/old-nil/swap-false": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return []byte("foo"), false, nil - }, - }, - err: ServerInternalErr(errors.New("error saving acme challenge; acme challenge has changed since last read")), - } - }, - "ok/old-nil": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - old: nil, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, nil) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, []byte(httpCh.getID()), key) - return []byte("foo"), true, nil - }, - }, - } - }, - "ok/old-not-nil": func(t *testing.T) test { - oldHTTPCh, err := newHTTPCh() - assert.FatalError(t, err) - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - - oldb, err := json.Marshal(oldHTTPCh) - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - old: oldHTTPCh, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, old, oldb) - assert.Equals(t, b, newval) - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, []byte(httpCh.getID()), key) - return []byte("foo"), true, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if err := tc.ch.save(tc.db, tc.old); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - assert.Nil(t, tc.err) - } - }) - } -} - -func TestChallengeClone(t *testing.T) { - ch, err := newHTTPCh() - assert.FatalError(t, err) - - clone := ch.clone() - - assert.Equals(t, clone.getID(), ch.getID()) - assert.Equals(t, clone.getAccountID(), ch.getAccountID()) - assert.Equals(t, clone.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, clone.getStatus(), ch.getStatus()) - assert.Equals(t, clone.getToken(), ch.getToken()) - assert.Equals(t, clone.getCreated(), ch.getCreated()) - assert.Equals(t, clone.getValidated(), ch.getValidated()) - - clone.Status = StatusValid - - assert.NotEquals(t, clone.getStatus(), ch.getStatus()) -} - -func TestChallengeUnmarshal(t *testing.T) { - type test struct { - ch challenge - chb []byte - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/nil": func(t *testing.T) test { - return test{ - chb: nil, - err: ServerInternalErr(errors.New("error unmarshaling challenge type: unexpected end of JSON input")), - } - }, - "fail/unexpected-type-http": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Type = "foo" - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "fail/unexpected-type-alpn": func(t *testing.T) test { - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - _tlsALPNCh, ok := tlsALPNCh.(*tlsALPN01Challenge) - assert.Fatal(t, ok) - _tlsALPNCh.baseChallenge.Type = "foo" - b, err := json.Marshal(tlsALPNCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "fail/unexpected-type-dns": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - _dnsCh, ok := dnsCh.(*dns01Challenge) - assert.Fatal(t, ok) - _dnsCh.baseChallenge.Type = "foo" - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - chb: b, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "ok/dns": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - ch: dnsCh, - chb: b, - } - }, - "ok/http": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - chb: b, - } - }, - "ok/alpn": func(t *testing.T) test { - tlsALPNCh, err := newTLSALPNCh() - assert.FatalError(t, err) - b, err := json.Marshal(tlsALPNCh) - assert.FatalError(t, err) - return test{ - ch: tlsALPNCh, - chb: b, - } - }, - "ok/err": func(t *testing.T) test { - httpCh, err := newHTTPCh() - assert.FatalError(t, err) - _httpCh, ok := httpCh.(*http01Challenge) - assert.Fatal(t, ok) - _httpCh.baseChallenge.Error = ServerInternalErr(errors.New("force")).ToACME() - b, err := json.Marshal(httpCh) - assert.FatalError(t, err) - return test{ - ch: httpCh, - chb: b, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ch, err := unmarshalChallenge(tc.chb); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.ch.getID(), ch.getID()) - assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) - assert.Equals(t, tc.ch.getToken(), ch.getToken()) - assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) - assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) - } - } - }) - } -} -func TestGetChallenge(t *testing.T) { - type test struct { - id string - db nosql.DB - ch challenge - err *Error - } - tests := map[string]func(t *testing.T) test{ - "fail/not-found": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, database.ErrNotFound - }, - }, - err: MalformedErr(errors.Errorf("challenge %s not found: not found", dnsCh.getID())), - } - }, - "fail/db-error": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - return nil, errors.New("force") - }, - }, - err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", dnsCh.getID())), - } - }, - "fail/unmarshal-error": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - _dnsCh, ok := dnsCh.(*dns01Challenge) - assert.Fatal(t, ok) - _dnsCh.baseChallenge.Type = "foo" - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(dnsCh.getID())) - return b, nil - }, - }, - err: ServerInternalErr(errors.New("unexpected challenge type foo")), - } - }, - "ok": func(t *testing.T) test { - dnsCh, err := newDNSCh() - assert.FatalError(t, err) - b, err := json.Marshal(dnsCh) - assert.FatalError(t, err) - return test{ - ch: dnsCh, - id: dnsCh.getID(), - db: &db.MockNoSQLDB{ - MGet: func(bucket, key []byte) ([]byte, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(dnsCh.getID())) - return b, nil - }, - }, - } - }, - } - for name, run := range tests { - t.Run(name, func(t *testing.T) { - tc := run(t) - if ch, err := getChallenge(tc.db, tc.id); err != nil { - if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) - } - } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.ch.getID(), ch.getID()) - assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) - assert.Equals(t, tc.ch.getToken(), ch.getToken()) - assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) - assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) - } - } - }) - } -} + "github.com/pkg/errors" + "github.com/smallstep/assert" + "go.step.sm/crypto/jose" +) func TestKeyAuthorization(t *testing.T) { type test struct { token string jwk *jose.JSONWebKey exp string - err *Error + err error } tests := map[string]func(t *testing.T) test{ "fail/jwk-thumbprint-error": func(t *testing.T) test { @@ -647,7 +31,7 @@ func TestKeyAuthorization(t *testing.T) { return test{ token: "1234", jwk: jwk, - err: ServerInternalErr(errors.Errorf("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + err: errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), } }, "ok": func(t *testing.T) test { @@ -669,11 +53,7 @@ func TestKeyAuthorization(t *testing.T) { tc := run(t) if ka, err := KeyAuthorization(tc.token, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + assert.HasPrefix(t, err.Error(), tc.err.Error()) } } else { if assert.Nil(t, tc.err) { @@ -695,258 +75,357 @@ func (errReader) Close() error { func TestHTTP01Validate(t *testing.T) { type test struct { - vo validateOptions - ch challenge - res challenge + vo *ValidateChallengeOptions + ch *Challenge jwk *jose.JSONWebKey - db nosql.DB + db DB err *Error } tests := map[string]func(t *testing.T) test{ - "ok/status-already-valid": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusValid - return test{ - ch: ch, - res: ch, + "fail/http-get-error-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", } - }, - "ok/status-already-invalid": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Status = StatusInvalid - return test{ - ch: ch, - res: ch, - } - }, - "ok/http-get-error": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) - expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+ - "http://zap.internal/.well-known/acme-challenge/%s: force", ch.getToken())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok/http-get->=400": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "ok/http-get-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } - expErr := ConnectionErr(errors.Errorf("error doing http GET for url "+ - "http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.getToken())) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return nil, errors.New("force") + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/http-get->=400-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, }, nil }, }, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - res: ch, + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "fail/read-body": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) - assert.FatalError(t, err) - jwk.Key = "foo" + "ok/http-get->=400": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + }, nil + }, + }, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + + err := NewError(ErrorConnectionType, "error doing http GET for url http://zap.internal/.well-known/acme-challenge/%s with status code 400", ch.Token) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil + }, + }, + } + }, + "fail/read-body": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ Body: errReader(0), }, nil }, }, - jwk: jwk, - err: ServerInternalErr(errors.Errorf("error reading response "+ - "body for url http://zap.internal/.well-known/acme-challenge/%s: force", - ch.getToken())), + err: NewErrorISE("error reading response body for url http://zap.internal/.well-known/acme-challenge/%s: force", ch.Token), } }, - "fail/key-authorization-gen-error": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) + "fail/key-auth-gen-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) jwk.Key = "foo" return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ Body: ioutil.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, - err: ServerInternalErr(errors.New("error generating JWK thumbprint: square/go-jose: unknown key type 'string'")), + err: NewErrorISE("error generating JWK thumbprint: square/go-jose: unknown key type 'string'"), } }, "ok/key-auth-mismatch": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - - expErr := RejectedIdentifierErr(errors.Errorf("keyAuthorization does not match; "+ - "expected %s, but got foo", expKeyAuth)) - baseClone := ch.clone() - baseClone.Error = expErr.ToACME() - newCh := &http01Challenge{baseClone} - newb, err := json.Marshal(newCh) - assert.FatalError(t, err) - return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ Body: ioutil.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - assert.Equals(t, newval, newb) - return nil, true, nil + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return nil }, }, - res: ch, } }, - "fail/save-error": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) + "fail/key-auth-mismatch-store-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) return test{ ch: ch, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ - Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), + Body: ioutil.NopCloser(bytes.NewBufferString("foo")), }, nil }, }, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - return nil, false, errors.New("force") + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + err := NewError(ErrorRejectedIdentifierType, + "keyAuthorization does not match; expected %s, but got foo", expKeyAuth) + assert.HasPrefix(t, updch.Error.Err.Error(), err.Err.Error()) + assert.Equals(t, updch.Error.Type, err.Type) + assert.Equals(t, updch.Error.Detail, err.Detail) + assert.Equals(t, updch.Error.Status, err.Status) + assert.Equals(t, updch.Error.Detail, err.Detail) + return errors.New("force") }, }, - err: ServerInternalErr(errors.New("error saving acme challenge: force")), + err: NewErrorISE("failure saving error to acme challenge: force"), } }, - "ok": func(t *testing.T) test { - ch, err := newHTTPCh() - assert.FatalError(t, err) - _ch, ok := ch.(*http01Challenge) - assert.Fatal(t, ok) - _ch.baseChallenge.Error = MalformedErr(nil).ToACME() - oldb, err := json.Marshal(ch) - assert.FatalError(t, err) + "fail/update-challenge-error": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) assert.FatalError(t, err) - expKeyAuth, err := KeyAuthorization(ch.getToken(), jwk) + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) assert.FatalError(t, err) - - baseClone := ch.clone() - baseClone.Status = StatusValid - baseClone.Error = nil - newCh := &http01Challenge{baseClone} - return test{ - ch: ch, - res: newCh, - vo: validateOptions{ - httpGet: func(url string) (*http.Response, error) { + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { return &http.Response{ Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil }, }, jwk: jwk, - db: &db.MockNoSQLDB{ - MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { - assert.Equals(t, bucket, challengeTable) - assert.Equals(t, key, []byte(ch.getID())) - assert.Equals(t, old, oldb) - - httpCh, err := unmarshalChallenge(newval) + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) assert.FatalError(t, err) - assert.Equals(t, httpCh.getStatus(), StatusValid) - assert.True(t, httpCh.getValidated().Before(time.Now().UTC().Add(time.Minute))) - assert.True(t, httpCh.getValidated().After(time.Now().UTC().Add(-1*time.Second))) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) - baseClone.Validated = httpCh.getValidated() + return errors.New("force") + }, + }, + err: NewErrorISE("error updating challenge: force"), + } + }, + "ok": func(t *testing.T) test { + ch := &Challenge{ + ID: "chID", + AuthzID: "azID", + Token: "token", + Value: "zap.internal", + } - return nil, true, nil + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + + expKeyAuth, err := KeyAuthorization(ch.Token, jwk) + assert.FatalError(t, err) + return test{ + ch: ch, + vo: &ValidateChallengeOptions{ + HTTPGet: func(url string) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), + }, nil + }, + }, + jwk: jwk, + db: &MockDB{ + MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { + assert.Equals(t, updch.ID, ch.ID) + assert.Equals(t, updch.AuthzID, ch.AuthzID) + assert.Equals(t, updch.Token, ch.Token) + assert.Equals(t, updch.Value, ch.Value) + + assert.Equals(t, updch.Status, StatusValid) + assert.Equals(t, updch.Error, nil) + va, err := time.Parse(time.RFC3339, updch.ValidatedAt) + assert.FatalError(t, err) + now := clock.Now() + assert.True(t, va.Add(-time.Minute).Before(now)) + assert.True(t, va.Add(time.Minute).After(now)) + return nil }, }, } @@ -955,30 +434,27 @@ func TestHTTP01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if ch, err := tc.ch.validate(tc.db, tc.jwk, tc.vo); err != nil { + if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { if assert.NotNil(t, tc.err) { - ae, ok := err.(*Error) - assert.True(t, ok) - assert.HasPrefix(t, ae.Error(), tc.err.Error()) - assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) - assert.Equals(t, ae.Type, tc.err.Type) + switch k := err.(type) { + case *Error: + assert.Equals(t, k.Type, tc.err.Type) + assert.Equals(t, k.Detail, tc.err.Detail) + assert.Equals(t, k.Status, tc.err.Status) + assert.Equals(t, k.Err.Error(), tc.err.Err.Error()) + assert.Equals(t, k.Detail, tc.err.Detail) + default: + assert.FatalError(t, errors.New("unexpected error type")) + } } } else { - if assert.Nil(t, tc.err) { - assert.Equals(t, tc.res.getID(), ch.getID()) - assert.Equals(t, tc.res.getAccountID(), ch.getAccountID()) - assert.Equals(t, tc.res.getAuthzID(), ch.getAuthzID()) - assert.Equals(t, tc.res.getStatus(), ch.getStatus()) - assert.Equals(t, tc.res.getToken(), ch.getToken()) - assert.Equals(t, tc.res.getCreated(), ch.getCreated()) - assert.Equals(t, tc.res.getValidated(), ch.getValidated()) - assert.Equals(t, tc.res.getError(), ch.getError()) - } + assert.Nil(t, tc.err) } }) } } +/* func TestTLSALPN01Validate(t *testing.T) { type test struct { srv *httptest.Server @@ -1960,4 +1436,636 @@ func TestDNS01Validate(t *testing.T) { }) } } + +/* +var testOps = ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "", // will get set correctly depending on the "new.." method. + Value: "zap.internal", + }, +} + +func newDNSCh() (Challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newDNS01Challenge(mockdb, testOps) +} + +func newTLSALPNCh() (Challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newTLSALPN01Challenge(mockdb, testOps) +} + +func newHTTPCh() (Challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newHTTP01Challenge(mockdb, testOps) +} + +func newHTTPChWithServer(host string) (Challenge, error) { + mockdb := &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + } + return newHTTP01Challenge(mockdb, ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "", // will get set correctly depending on the "new.." method. + Value: host, + }, + }) +} + +func TestNewHTTP01Challenge(t *testing.T) { + ops := ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "http", + Value: "zap.internal", + }, + } + type test struct { + ops ChallengeOptions + db nosql.DB + err *Error + } + tests := map[string]test{ + "fail/store-error": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + }, + "ok": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ch, err := newHTTP01Challenge(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.getAccountID(), ops.AccountID) + assert.Equals(t, ch.getAuthzID(), ops.AuthzID) + assert.Equals(t, ch.getType(), "http-01") + assert.Equals(t, ch.getValue(), "zap.internal") + assert.Equals(t, ch.getStatus(), StatusPending) + + assert.True(t, ch.getValidated().IsZero()) + assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + assert.True(t, ch.getID() != "") + assert.True(t, ch.getToken() != "") + } + } + }) + } +} + +func TestNewTLSALPN01Challenge(t *testing.T) { + ops := ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "http", + Value: "zap.internal", + }, + } + type test struct { + ops ChallengeOptions + db nosql.DB + err *Error + } + tests := map[string]test{ + "fail/store-error": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + }, + "ok": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ch, err := newTLSALPN01Challenge(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.getAccountID(), ops.AccountID) + assert.Equals(t, ch.getAuthzID(), ops.AuthzID) + assert.Equals(t, ch.getType(), "tls-alpn-01") + assert.Equals(t, ch.getValue(), "zap.internal") + assert.Equals(t, ch.getStatus(), StatusPending) + + assert.True(t, ch.getValidated().IsZero()) + assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + assert.True(t, ch.getID() != "") + assert.True(t, ch.getToken() != "") + } + } + }) + } +} + +func TestNewDNS01Challenge(t *testing.T) { + ops := ChallengeOptions{ + AccountID: "accID", + AuthzID: "authzID", + Identifier: Identifier{ + Type: "dns", + Value: "zap.internal", + }, + } + type test struct { + ops ChallengeOptions + db nosql.DB + err *Error + } + tests := map[string]test{ + "fail/store-error": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + }, + "ok": { + ops: ops, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), true, nil + }, + }, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + ch, err := newDNS01Challenge(tc.db, tc.ops) + if err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, ch.getAccountID(), ops.AccountID) + assert.Equals(t, ch.getAuthzID(), ops.AuthzID) + assert.Equals(t, ch.getType(), "dns-01") + assert.Equals(t, ch.getValue(), "zap.internal") + assert.Equals(t, ch.getStatus(), StatusPending) + + assert.True(t, ch.getValidated().IsZero()) + assert.True(t, ch.getCreated().Before(time.Now().UTC().Add(time.Minute))) + assert.True(t, ch.getCreated().After(time.Now().UTC().Add(-1*time.Minute))) + + assert.True(t, ch.getID() != "") + assert.True(t, ch.getToken() != "") + } + } + }) + } +} + +func TestChallengeToACME(t *testing.T) { + dir := newDirectory("ca.smallstep.com", "acme") + + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + _httpCh, ok := httpCh.(*http01Challenge) + assert.Fatal(t, ok) + _httpCh.baseChallenge.Validated = clock.Now() + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + tlsALPNCh, err := newTLSALPNCh() + assert.FatalError(t, err) + + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov) + ctx = context.WithValue(ctx, BaseURLContextKey, baseURL) + tests := map[string]challenge{ + "dns": dnsCh, + "http": httpCh, + "tls-alpn": tlsALPNCh, + } + for name, ch := range tests { + t.Run(name, func(t *testing.T) { + ach, err := ch.toACME(ctx, nil, dir) + assert.FatalError(t, err) + + assert.Equals(t, ach.Type, ch.getType()) + assert.Equals(t, ach.Status, ch.getStatus()) + assert.Equals(t, ach.Token, ch.getToken()) + assert.Equals(t, ach.URL, + fmt.Sprintf("%s/acme/%s/challenge/%s", + baseURL.String(), provName, ch.getID())) + assert.Equals(t, ach.ID, ch.getID()) + assert.Equals(t, ach.AuthzID, ch.getAuthzID()) + + if ach.Type == "http-01" { + v, err := time.Parse(time.RFC3339, ach.Validated) + assert.FatalError(t, err) + assert.Equals(t, v.String(), _httpCh.baseChallenge.Validated.String()) + } else { + assert.Equals(t, ach.Validated, "") + } + }) + } +} + +func TestChallengeSave(t *testing.T) { + type test struct { + ch challenge + old challenge + db nosql.DB + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/old-nil/swap-error": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + return test{ + ch: httpCh, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return nil, false, errors.New("force") + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge: force")), + } + }, + "fail/old-nil/swap-false": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + return test{ + ch: httpCh, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + return []byte("foo"), false, nil + }, + }, + err: ServerInternalErr(errors.New("error saving acme challenge; acme challenge has changed since last read")), + } + }, + "ok/old-nil": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + ch: httpCh, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, nil) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, []byte(httpCh.getID()), key) + return []byte("foo"), true, nil + }, + }, + } + }, + "ok/old-not-nil": func(t *testing.T) test { + oldHTTPCh, err := newHTTPCh() + assert.FatalError(t, err) + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + + oldb, err := json.Marshal(oldHTTPCh) + assert.FatalError(t, err) + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + ch: httpCh, + old: oldHTTPCh, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) { + assert.Equals(t, old, oldb) + assert.Equals(t, b, newval) + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, []byte(httpCh.getID()), key) + return []byte("foo"), true, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if err := tc.ch.save(tc.db, tc.old); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + assert.Nil(t, tc.err) + } + }) + } +} + +func TestChallengeClone(t *testing.T) { + ch, err := newHTTPCh() + assert.FatalError(t, err) + + clone := ch.clone() + + assert.Equals(t, clone.getID(), ch.getID()) + assert.Equals(t, clone.getAccountID(), ch.getAccountID()) + assert.Equals(t, clone.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, clone.getStatus(), ch.getStatus()) + assert.Equals(t, clone.getToken(), ch.getToken()) + assert.Equals(t, clone.getCreated(), ch.getCreated()) + assert.Equals(t, clone.getValidated(), ch.getValidated()) + + clone.Status = StatusValid + + assert.NotEquals(t, clone.getStatus(), ch.getStatus()) +} + +func TestChallengeUnmarshal(t *testing.T) { + type test struct { + ch challenge + chb []byte + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/nil": func(t *testing.T) test { + return test{ + chb: nil, + err: ServerInternalErr(errors.New("error unmarshaling challenge type: unexpected end of JSON input")), + } + }, + "fail/unexpected-type-http": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + _httpCh, ok := httpCh.(*http01Challenge) + assert.Fatal(t, ok) + _httpCh.baseChallenge.Type = "foo" + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + chb: b, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), + } + }, + "fail/unexpected-type-alpn": func(t *testing.T) test { + tlsALPNCh, err := newTLSALPNCh() + assert.FatalError(t, err) + _tlsALPNCh, ok := tlsALPNCh.(*tlsALPN01Challenge) + assert.Fatal(t, ok) + _tlsALPNCh.baseChallenge.Type = "foo" + b, err := json.Marshal(tlsALPNCh) + assert.FatalError(t, err) + return test{ + chb: b, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), + } + }, + "fail/unexpected-type-dns": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + _dnsCh, ok := dnsCh.(*dns01Challenge) + assert.Fatal(t, ok) + _dnsCh.baseChallenge.Type = "foo" + b, err := json.Marshal(dnsCh) + assert.FatalError(t, err) + return test{ + chb: b, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), + } + }, + "ok/dns": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + b, err := json.Marshal(dnsCh) + assert.FatalError(t, err) + return test{ + ch: dnsCh, + chb: b, + } + }, + "ok/http": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + ch: httpCh, + chb: b, + } + }, + "ok/alpn": func(t *testing.T) test { + tlsALPNCh, err := newTLSALPNCh() + assert.FatalError(t, err) + b, err := json.Marshal(tlsALPNCh) + assert.FatalError(t, err) + return test{ + ch: tlsALPNCh, + chb: b, + } + }, + "ok/err": func(t *testing.T) test { + httpCh, err := newHTTPCh() + assert.FatalError(t, err) + _httpCh, ok := httpCh.(*http01Challenge) + assert.Fatal(t, ok) + _httpCh.baseChallenge.Error = ServerInternalErr(errors.New("force")).ToACME() + b, err := json.Marshal(httpCh) + assert.FatalError(t, err) + return test{ + ch: httpCh, + chb: b, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ch, err := unmarshalChallenge(tc.chb); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.getID(), ch.getID()) + assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) + assert.Equals(t, tc.ch.getToken(), ch.getToken()) + assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) + assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) + } + } + }) + } +} +func TestGetChallenge(t *testing.T) { + type test struct { + id string + db nosql.DB + ch challenge + err *Error + } + tests := map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, + err: MalformedErr(errors.Errorf("challenge %s not found: not found", dnsCh.getID())), + } + }, + "fail/db-error": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("force") + }, + }, + err: ServerInternalErr(errors.Errorf("error loading challenge %s: force", dnsCh.getID())), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + _dnsCh, ok := dnsCh.(*dns01Challenge) + assert.Fatal(t, ok) + _dnsCh.baseChallenge.Type = "foo" + b, err := json.Marshal(dnsCh) + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(dnsCh.getID())) + return b, nil + }, + }, + err: ServerInternalErr(errors.New("unexpected challenge type foo")), + } + }, + "ok": func(t *testing.T) test { + dnsCh, err := newDNSCh() + assert.FatalError(t, err) + b, err := json.Marshal(dnsCh) + assert.FatalError(t, err) + return test{ + ch: dnsCh, + id: dnsCh.getID(), + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, key, []byte(dnsCh.getID())) + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + t.Run(name, func(t *testing.T) { + tc := run(t) + if ch, err := getChallenge(tc.db, tc.id); err != nil { + if assert.NotNil(t, tc.err) { + ae, ok := err.(*Error) + assert.True(t, ok) + assert.HasPrefix(t, ae.Error(), tc.err.Error()) + assert.Equals(t, ae.StatusCode(), tc.err.StatusCode()) + assert.Equals(t, ae.Type, tc.err.Type) + } + } else { + if assert.Nil(t, tc.err) { + assert.Equals(t, tc.ch.getID(), ch.getID()) + assert.Equals(t, tc.ch.getAccountID(), ch.getAccountID()) + assert.Equals(t, tc.ch.getAuthzID(), ch.getAuthzID()) + assert.Equals(t, tc.ch.getStatus(), ch.getStatus()) + assert.Equals(t, tc.ch.getToken(), ch.getToken()) + assert.Equals(t, tc.ch.getCreated(), ch.getCreated()) + assert.Equals(t, tc.ch.getValidated(), ch.getValidated()) + } + } + }) + } +} */ diff --git a/acme/db/nosql/authz.go b/acme/db/nosql/authz.go index 2ea1bb69..449a9276 100644 --- a/acme/db/nosql/authz.go +++ b/acme/db/nosql/authz.go @@ -14,16 +14,16 @@ var defaultExpiryDuration = time.Hour * 24 // dbAuthz is the base authz type that others build from. type dbAuthz struct { - ID string `json:"id"` - AccountID string `json:"accountID"` - Identifier acme.Identifier `json:"identifier"` - Status acme.Status `json:"status"` - ExpiresAt time.Time `json:"expiresAt"` - Challenges []string `json:"challenges"` - Wildcard bool `json:"wildcard"` - CreatedAt time.Time `json:"createdAt"` - Error *acme.Error `json:"error"` - Token string `json:"token"` + ID string `json:"id"` + AccountID string `json:"accountID"` + Identifier acme.Identifier `json:"identifier"` + Status acme.Status `json:"status"` + ExpiresAt time.Time `json:"expiresAt"` + ChallengeIDs []string `json:"challengeIDs"` + Wildcard bool `json:"wildcard"` + CreatedAt time.Time `json:"createdAt"` + Error *acme.Error `json:"error"` + Token string `json:"token"` } func (ba *dbAuthz) clone() *dbAuthz { @@ -55,8 +55,8 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat if err != nil { return nil, err } - var chs = make([]*acme.Challenge, len(dbaz.Challenges)) - for i, chID := range dbaz.Challenges { + var chs = make([]*acme.Challenge, len(dbaz.ChallengeIDs)) + for i, chID := range dbaz.ChallengeIDs { chs[i], err = db.GetChallenge(ctx, chID, id) if err != nil { return nil, err @@ -91,15 +91,15 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e now := clock.Now() dbaz := &dbAuthz{ - ID: az.ID, - AccountID: az.AccountID, - Status: az.Status, - CreatedAt: now, - ExpiresAt: az.ExpiresAt, - Identifier: az.Identifier, - Challenges: chIDs, - Token: az.Token, - Wildcard: az.Wildcard, + ID: az.ID, + AccountID: az.AccountID, + Status: az.Status, + CreatedAt: now, + ExpiresAt: az.ExpiresAt, + Identifier: az.Identifier, + ChallengeIDs: chIDs, + Token: az.Token, + Wildcard: az.Wildcard, } return db.save(ctx, az.ID, dbaz, nil, "authz", authzTable) diff --git a/acme/db/nosql/authz_test.go b/acme/db/nosql/authz_test.go index 825c4648..0c2cec50 100644 --- a/acme/db/nosql/authz_test.go +++ b/acme/db/nosql/authz_test.go @@ -71,13 +71,13 @@ func TestDB_getDBAuthz(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Error: acme.NewErrorISE("force"), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -174,13 +174,13 @@ func TestDB_GetAuthorization(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Error: acme.NewErrorISE("force"), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -212,13 +212,13 @@ func TestDB_GetAuthorization(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Error: acme.NewErrorISE("force"), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -250,13 +250,13 @@ func TestDB_GetAuthorization(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Error: acme.NewErrorISE("force"), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + Error: acme.NewErrorISE("force"), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -374,7 +374,7 @@ func TestDB_CreateAuthorization(t *testing.T) { }) assert.Equals(t, dbaz.Status, az.Status) assert.Equals(t, dbaz.Token, az.Token) - assert.Equals(t, dbaz.Challenges, []string{"foo", "bar"}) + assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"}) assert.Equals(t, dbaz.Wildcard, az.Wildcard) assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) assert.Nil(t, dbaz.Error) @@ -428,7 +428,7 @@ func TestDB_CreateAuthorization(t *testing.T) { }) assert.Equals(t, dbaz.Status, az.Status) assert.Equals(t, dbaz.Token, az.Token) - assert.Equals(t, dbaz.Challenges, []string{"foo", "bar"}) + assert.Equals(t, dbaz.ChallengeIDs, []string{"foo", "bar"}) assert.Equals(t, dbaz.Wildcard, az.Wildcard) assert.Equals(t, dbaz.ExpiresAt, az.ExpiresAt) assert.Nil(t, dbaz.Error) @@ -469,12 +469,12 @@ func TestDB_UpdateAuthorization(t *testing.T) { Type: "dns", Value: "test.ca.smallstep.com", }, - Status: acme.StatusPending, - Token: "token", - CreatedAt: now, - ExpiresAt: now.Add(5 * time.Minute), - Challenges: []string{"foo", "bar"}, - Wildcard: true, + Status: acme.StatusPending, + Token: "token", + CreatedAt: now, + ExpiresAt: now.Add(5 * time.Minute), + ChallengeIDs: []string{"foo", "bar"}, + Wildcard: true, } b, err := json.Marshal(dbaz) assert.FatalError(t, err) @@ -530,7 +530,7 @@ func TestDB_UpdateAuthorization(t *testing.T) { assert.Equals(t, dbNew.Identifier, dbaz.Identifier) assert.Equals(t, dbNew.Status, acme.StatusValid) assert.Equals(t, dbNew.Token, dbaz.Token) - assert.Equals(t, dbNew.Challenges, dbaz.Challenges) + assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs) assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) @@ -580,7 +580,7 @@ func TestDB_UpdateAuthorization(t *testing.T) { assert.Equals(t, dbNew.Identifier, dbaz.Identifier) assert.Equals(t, dbNew.Status, acme.StatusValid) assert.Equals(t, dbNew.Token, dbaz.Token) - assert.Equals(t, dbNew.Challenges, dbaz.Challenges) + assert.Equals(t, dbNew.ChallengeIDs, dbaz.ChallengeIDs) assert.Equals(t, dbNew.Wildcard, dbaz.Wildcard) assert.Equals(t, dbNew.CreatedAt, dbaz.CreatedAt) assert.Equals(t, dbNew.ExpiresAt, dbaz.ExpiresAt) diff --git a/acme/db/nosql/nosql.go b/acme/db/nosql/nosql.go index bcb118d8..b8f79edc 100644 --- a/acme/db/nosql/nosql.go +++ b/acme/db/nosql/nosql.go @@ -42,9 +42,17 @@ func New(db nosqlDB.DB) (*DB, error) { // save writes the new data to the database, overwriting the old data if it // existed. func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { - newB, err := json.Marshal(nu) - if err != nil { - return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu) + var ( + err error + newB []byte + ) + if nu == nil { + newB = nil + } else { + newB, err = json.Marshal(nu) + if err != nil { + return errors.Wrapf(err, "error marshaling acme type: %s, value: %v", typ, nu) + } } var oldB []byte if old == nil { diff --git a/acme/db/nosql/nosql_test.go b/acme/db/nosql/nosql_test.go index b7a91a2f..7fd21c50 100644 --- a/acme/db/nosql/nosql_test.go +++ b/acme/db/nosql/nosql_test.go @@ -110,6 +110,19 @@ func TestDB_save(t *testing.T) { }, }, }, + "ok/nils": test{ + nu: nil, + old: nil, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, challengeTable) + assert.Equals(t, string(key), "id") + assert.Equals(t, old, nil) + assert.Equals(t, nu, nil) + return nu, true, nil + }, + }, + }, } for name, tc := range tests { t.Run(name, func(t *testing.T) { diff --git a/acme/db/nosql/order.go b/acme/db/nosql/order.go index a64316a6..c8fe53e1 100644 --- a/acme/db/nosql/order.go +++ b/acme/db/nosql/order.go @@ -127,15 +127,17 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st defer ordersByAccountMux.Unlock() b, err := db.db.Get(ordersByAccountIDTable, []byte(accID)) + var ( + oldOids []string + ) if err != nil { - if nosql.IsErrNotFound(err) { - return []string{}, nil + if !nosql.IsErrNotFound(err) { + return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) + } + } else { + if err := json.Unmarshal(b, &oldOids); err != nil { + return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID) } - return nil, errors.Wrapf(err, "error loading orderIDs for account %s", accID) - } - var oids []string - if err := json.Unmarshal(b, &oids); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling orderIDs for account %s", accID) } // Remove any order that is not in PENDING state and update the stored list @@ -145,7 +147,7 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st // The server SHOULD include pending orders and SHOULD NOT include orders // that are invalid in the array of URLs. pendOids := []string{} - for _, oid := range oids { + for _, oid := range oldOids { o, err := db.GetOrder(ctx, oid) if err != nil { return nil, acme.WrapErrorISE(err, "error loading order %s for account %s", oid, accID) @@ -158,15 +160,27 @@ func (db *DB) updateAddOrderIDs(ctx context.Context, accID string, addOids ...st } } pendOids = append(pendOids, addOids...) - if len(oids) == 0 { - oids = nil + var ( + _old interface{} = oldOids + _new interface{} = pendOids + ) + switch { + case len(oldOids) == 0 && len(pendOids) == 0: + // If list has not changed from empty, then no need to write the DB. + return []string{}, nil + case len(oldOids) == 0: + _old = nil + case len(pendOids) == 0: + _new = nil } - if err = db.save(ctx, accID, pendOids, oids, "orderIDsByAccountID", ordersByAccountIDTable); err != nil { + if err = db.save(ctx, accID, _new, _old, "orderIDsByAccountID", ordersByAccountIDTable); err != nil { // Delete all orders that may have been previously stored if orderIDsByAccountID update fails. for _, oid := range addOids { + // Ignore error from delete -- we tried our best. + // TODO when we have logging w/ request ID tracking, logging this error. db.db.Del(orderTable, []byte(oid)) } - return nil, errors.Wrap(err, "error saving OrderIDsByAccountID index") + return nil, errors.Wrapf(err, "error saving orderIDs index for account %s", accID) } return pendOids, nil } diff --git a/acme/db/nosql/order_test.go b/acme/db/nosql/order_test.go index 8ce7ac79..3636837c 100644 --- a/acme/db/nosql/order_test.go +++ b/acme/db/nosql/order_test.go @@ -3,6 +3,7 @@ package nosql import ( "context" "encoding/json" + "reflect" "testing" "time" @@ -511,27 +512,39 @@ func TestDB_CreateOrder(t *testing.T) { return nil, nosqldb.ErrNotFound }, MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { - *idptr = string(key) - assert.Equals(t, string(bucket), string(orderTable)) - assert.Equals(t, string(key), o.ID) - assert.Equals(t, old, nil) + switch string(bucket) { + case string(ordersByAccountIDTable): + b, err := json.Marshal([]string{o.ID}) + assert.FatalError(t, err) + assert.Equals(t, string(key), "accID") + assert.Equals(t, old, nil) + assert.Equals(t, nu, b) + return nu, true, nil + case string(orderTable): + *idptr = string(key) + assert.Equals(t, string(key), o.ID) + assert.Equals(t, old, nil) - dbo := new(dbOrder) - assert.FatalError(t, json.Unmarshal(nu, dbo)) - assert.Equals(t, dbo.ID, o.ID) - assert.Equals(t, dbo.AccountID, o.AccountID) - assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) - assert.Equals(t, dbo.CertificateID, "") - assert.Equals(t, dbo.Status, o.Status) - assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) - assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) - assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) - assert.Equals(t, dbo.NotBefore, o.NotBefore) - assert.Equals(t, dbo.NotAfter, o.NotAfter) - assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) - assert.Equals(t, dbo.Identifiers, o.Identifiers) - assert.Equals(t, dbo.Error, nil) - return nu, true, nil + dbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, dbo)) + assert.Equals(t, dbo.ID, o.ID) + assert.Equals(t, dbo.AccountID, o.AccountID) + assert.Equals(t, dbo.ProvisionerID, o.ProvisionerID) + assert.Equals(t, dbo.CertificateID, "") + assert.Equals(t, dbo.Status, o.Status) + assert.True(t, dbo.CreatedAt.Add(-time.Minute).Before(now)) + assert.True(t, dbo.CreatedAt.Add(time.Minute).After(now)) + assert.Equals(t, dbo.ExpiresAt, o.ExpiresAt) + assert.Equals(t, dbo.NotBefore, o.NotBefore) + assert.Equals(t, dbo.NotAfter, o.NotAfter) + assert.Equals(t, dbo.AuthorizationIDs, o.AuthorizationIDs) + assert.Equals(t, dbo.Identifiers, o.Identifiers) + assert.Equals(t, dbo.Error, nil) + return nu, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } }, }, o: o, @@ -555,3 +568,434 @@ func TestDB_CreateOrder(t *testing.T) { }) } } + +func TestDB_updateAddOrderIDs(t *testing.T) { + accID := "accID" + type test struct { + db nosql.DB + err error + acmeErr *acme.Error + addOids []string + res []string + } + var tests = map[string]func(t *testing.T) test{ + "fail/db.Get-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return nil, errors.New("force") + }, + }, + err: errors.Errorf("error loading orderIDs for account %s", accID), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return []byte("foo"), nil + }, + }, + err: errors.Errorf("error unmarshaling orderIDs for account %s", accID), + } + }, + "fail/db.Get-order-error": func(t *testing.T) test { + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + b, err := json.Marshal([]string{"foo", "bar"}) + assert.FatalError(t, err) + return b, nil + case string(orderTable): + assert.Equals(t, key, []byte("foo")) + return nil, errors.New("force") + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + }, + acmeErr: acme.NewErrorISE("error loading order foo for account accID: error loading order foo: force"), + } + }, + "fail/update-order-status-error": func(t *testing.T) test { + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + b, err := json.Marshal([]string{"foo", "bar"}) + assert.FatalError(t, err) + return b, nil + case string(orderTable): + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("foo")) + assert.Equals(t, old, bfoo) + + newdbo := new(dbOrder) + assert.FatalError(t, json.Unmarshal(nu, newdbo)) + assert.Equals(t, newdbo.ID, "foo") + assert.Equals(t, newdbo.Status, acme.StatusInvalid) + assert.Equals(t, newdbo.ExpiresAt, expiry) + assert.Equals(t, newdbo.Error.Error(), acme.NewError(acme.ErrorMalformedType, "order has expired").Error()) + return nil, false, errors.New("force") + }, + }, + acmeErr: acme.NewErrorISE("error updating order foo for account accID: error saving acme order: force"), + } + }, + "fail/db.save-order-error": func(t *testing.T) test { + addOids := []string{"foo", "bar"} + b, err := json.Marshal(addOids) + assert.FatalError(t, err) + delCount := 0 + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + return nil, nosqldb.ErrNotFound + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, ordersByAccountIDTable) + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, nil) + assert.Equals(t, nu, b) + return nil, false, errors.New("force") + }, + MDel: func(bucket, key []byte) error { + delCount++ + switch delCount { + case 1: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("foo")) + return nil + case 2: + assert.Equals(t, bucket, orderTable) + assert.Equals(t, key, []byte("bar")) + return nil + default: + assert.FatalError(t, errors.New("delete should only be called twice")) + return errors.New("force") + } + }, + }, + addOids: addOids, + err: errors.Errorf("error saving orderIDs index for account %s", accID), + } + }, + "ok/all-old-not-pending": func(t *testing.T) test { + oldOids := []string{"foo", "bar"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + obar := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bbar, err := json.Marshal(obar) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + case "bar": + assert.Equals(t, key, []byte("bar")) + return bbar, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, nil) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + res: []string{}, + } + }, + "ok/old-and-new": func(t *testing.T) test { + oldOids := []string{"foo", "bar"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + addOids := []string{"zap", "zar"} + bAddOids, err := json.Marshal(addOids) + assert.FatalError(t, err) + expiry := clock.Now().Add(-5 * time.Minute) + ofoo := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bfoo, err := json.Marshal(ofoo) + assert.FatalError(t, err) + obar := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bbar, err := json.Marshal(obar) + assert.FatalError(t, err) + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + assert.Equals(t, key, []byte("foo")) + return bfoo, nil + case "bar": + assert.Equals(t, key, []byte("bar")) + return bbar, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, bAddOids) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + addOids: addOids, + res: addOids, + } + }, + "ok/old-and-new-2": func(t *testing.T) test { + oldOids := []string{"foo", "bar", "baz"} + bOldOids, err := json.Marshal(oldOids) + assert.FatalError(t, err) + addOids := []string{"zap", "zar"} + now := clock.Now() + min5 := now.Add(5 * time.Minute) + expiry := now.Add(-5 * time.Minute) + + o1 := &dbOrder{ + ID: "foo", + Status: acme.StatusPending, + ExpiresAt: min5, + AuthorizationIDs: []string{"a"}, + } + bo1, err := json.Marshal(o1) + assert.FatalError(t, err) + o2 := &dbOrder{ + ID: "bar", + Status: acme.StatusPending, + ExpiresAt: expiry, + } + bo2, err := json.Marshal(o2) + assert.FatalError(t, err) + o3 := &dbOrder{ + ID: "baz", + Status: acme.StatusPending, + ExpiresAt: min5, + AuthorizationIDs: []string{"b"}, + } + bo3, err := json.Marshal(o3) + assert.FatalError(t, err) + + az1 := &dbAuthz{ + ID: "a", + Status: acme.StatusPending, + ExpiresAt: min5, + ChallengeIDs: []string{"aa"}, + } + baz1, err := json.Marshal(az1) + assert.FatalError(t, err) + az2 := &dbAuthz{ + ID: "b", + Status: acme.StatusPending, + ExpiresAt: min5, + ChallengeIDs: []string{"bb"}, + } + baz2, err := json.Marshal(az2) + assert.FatalError(t, err) + + ch1 := &dbChallenge{ + ID: "aa", + Status: acme.StatusPending, + } + bch1, err := json.Marshal(ch1) + assert.FatalError(t, err) + ch2 := &dbChallenge{ + ID: "bb", + Status: acme.StatusPending, + } + bch2, err := json.Marshal(ch2) + assert.FatalError(t, err) + + newOids := append([]string{"foo", "baz"}, addOids...) + bNewOids, err := json.Marshal(newOids) + assert.FatalError(t, err) + + return test{ + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + switch string(bucket) { + case string(authzTable): + switch string(key) { + case "a": + return baz1, nil + case "b": + return baz2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected authz key %s", string(key))) + return nil, errors.New("force") + } + case string(challengeTable): + switch string(key) { + case "aa": + return bch1, nil + case "bb": + return bch2, nil + default: + assert.FatalError(t, errors.Errorf("unexpected challenge key %s", string(key))) + return nil, errors.New("force") + } + case string(ordersByAccountIDTable): + return bOldOids, nil + case string(orderTable): + switch string(key) { + case "foo": + return bo1, nil + case "bar": + return bo2, nil + case "baz": + return bo3, nil + default: + assert.FatalError(t, errors.Errorf("unexpected key %s", string(key))) + return nil, errors.New("force") + } + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, errors.New("force") + } + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + switch string(bucket) { + case string(orderTable): + return nil, true, nil + case string(ordersByAccountIDTable): + assert.Equals(t, key, []byte(accID)) + assert.Equals(t, old, bOldOids) + assert.Equals(t, nu, bNewOids) + return nil, true, nil + default: + assert.FatalError(t, errors.Errorf("unexpected bucket %s", string(bucket))) + return nil, false, errors.New("force") + } + }, + }, + addOids: addOids, + res: newOids, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + db := DB{db: tc.db} + var ( + res []string + err error + ) + if tc.addOids == nil { + res, err = db.updateAddOrderIDs(context.Background(), accID) + } else { + res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...) + } + + if err != nil { + switch k := err.(type) { + case *acme.Error: + if assert.NotNil(t, tc.acmeErr) { + assert.Equals(t, k.Type, tc.acmeErr.Type) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + assert.Equals(t, k.Status, tc.acmeErr.Status) + assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error()) + assert.Equals(t, k.Detail, tc.acmeErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else { + if assert.Nil(t, tc.err) { + assert.True(t, reflect.DeepEqual(res, tc.res)) + } + } + }) + } +}