diff --git a/acme/api/middleware.go b/acme/api/middleware.go index 55bc96bf..99980f50 100644 --- a/acme/api/middleware.go +++ b/acme/api/middleware.go @@ -378,6 +378,9 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { // canExtractJWKFrom checks if the JWS has a JWK that can be extracted func canExtractJWKFrom(jws *jose.JSONWebSignature) bool { + if jws == nil { + return false + } if len(jws.Signatures) == 0 { return false } diff --git a/acme/api/middleware_test.go b/acme/api/middleware_test.go index e8d22d53..1cc93de7 100644 --- a/acme/api/middleware_test.go +++ b/acme/api/middleware_test.go @@ -1473,3 +1473,191 @@ func TestHandler_validateJWS(t *testing.T) { }) } } + +func Test_canExtractJWKFrom(t *testing.T) { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + type args struct { + jws *jose.JSONWebSignature + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "no-jws", + args: args{ + jws: nil, + }, + want: false, + }, + { + name: "no-signatures", + args: args{ + jws: &jose.JSONWebSignature{ + Signatures: []jose.Signature{}, + }, + }, + want: false, + }, + { + name: "no-jwk", + args: args{ + jws: &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{}, + }, + }, + }, + }, + want: false, + }, + { + name: "ok", + args: args{ + jws: &jose.JSONWebSignature{ + Signatures: []jose.Signature{ + { + Protected: jose.Header{ + JSONWebKey: jwk, + }, + }, + }, + }, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := canExtractJWKFrom(tt.args.jws); got != tt.want { + t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHandler_extractOrLookupJWK(t *testing.T) { + u := "https://ca.smallstep.com/acme/account" + type test struct { + db acme.DB + linker Linker + statusCode int + ctx context.Context + err *acme.Error + next func(w http.ResponseWriter, r *http.Request) + } + var tests = map[string]func(t *testing.T) test{ + "ok/extract": func(t *testing.T) test { + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + kid, err := jwk.Thumbprint(crypto.SHA256) + assert.FatalError(t, err) + pub := jwk.Public() + pub.KeyID = base64.RawURLEncoding.EncodeToString(kid) + so := new(jose.SignerOptions) + so.WithHeader("jwk", pub) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + signed, err := signer.Sign([]byte("foo")) + assert.FatalError(t, err) + raw, err := signed.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + assert.FatalError(t, err) + acc := &acme.Account{Status: "valid"} + return test{ + linker: NewLinker("dns", "acme"), + db: &acme.MockDB{ + MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { + assert.Equals(t, kid, pub.KeyID) + return acc, nil + }, + }, + ctx: context.WithValue(context.Background(), jwsContextKey, parsedJWS), + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + } + }, + "ok/lookup": func(t *testing.T) test { + prov := newProv() + provName := url.PathEscape(prov.GetName()) + baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} + jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) + assert.FatalError(t, err) + accID := "accID" + prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName) + so := new(jose.SignerOptions) + so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID)) + signer, err := jose.NewSigner(jose.SigningKey{ + Algorithm: jose.SignatureAlgorithm(jwk.Algorithm), + Key: jwk.Key, + }, so) + assert.FatalError(t, err) + jws, err := signer.Sign([]byte("baz")) + assert.FatalError(t, err) + raw, err := jws.CompactSerialize() + assert.FatalError(t, err) + parsedJWS, err := jose.ParseJWS(raw) + assert.FatalError(t, err) + assert.FatalError(t, err) + //acc := &acme.Account{Status: "valid", Key: jwk} + acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} + ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) + return test{ + linker: NewLinker("test.ca.smallstep.com", "acme"), + db: &acme.MockDB{ + MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { + assert.Equals(t, accID, acc.ID) + return acc, nil + }, + }, + ctx: ctx, + statusCode: 200, + next: func(w http.ResponseWriter, r *http.Request) { + w.Write(testBody) + }, + } + }, + } + for name, prep := range tests { + tc := prep(t) + t.Run(name, func(t *testing.T) { + h := &Handler{db: tc.db, linker: tc.linker} + req := httptest.NewRequest("GET", u, nil) + req = req.WithContext(tc.ctx) + w := httptest.NewRecorder() + h.extractOrLookupJWK(tc.next)(w, req) + res := w.Result() + + assert.Equals(t, res.StatusCode, tc.statusCode) + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + assert.FatalError(t, err) + + if res.StatusCode >= 400 && assert.NotNil(t, tc.err) { + var ae acme.Error + assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae)) + + assert.Equals(t, ae.Type, tc.err.Type) + assert.Equals(t, ae.Detail, tc.err.Detail) + assert.Equals(t, ae.Identifier, tc.err.Identifier) + assert.Equals(t, ae.Subproblems, tc.err.Subproblems) + assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"}) + } else { + assert.Equals(t, bytes.TrimSpace(body), testBody) + } + }) + } +}