Add tests for extractOrLookupJWK middleware

This commit is contained in:
Herman Slatman 2021-11-12 16:37:44 +01:00
parent 3151255a25
commit c7a9c13060
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
2 changed files with 191 additions and 0 deletions

View file

@ -378,6 +378,9 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
// canExtractJWKFrom checks if the JWS has a JWK that can be extracted
func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
if jws == nil {
return false
}
if len(jws.Signatures) == 0 {
return false
}

View file

@ -1473,3 +1473,191 @@ func TestHandler_validateJWS(t *testing.T) {
})
}
}
func Test_canExtractJWKFrom(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
type args struct {
jws *jose.JSONWebSignature
}
tests := []struct {
name string
args args
want bool
}{
{
name: "no-jws",
args: args{
jws: nil,
},
want: false,
},
{
name: "no-signatures",
args: args{
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{},
},
},
want: false,
},
{
name: "no-jwk",
args: args{
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{
{
Protected: jose.Header{},
},
},
},
},
want: false,
},
{
name: "ok",
args: args{
jws: &jose.JSONWebSignature{
Signatures: []jose.Signature{
{
Protected: jose.Header{
JSONWebKey: jwk,
},
},
},
},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := canExtractJWKFrom(tt.args.jws); got != tt.want {
t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want)
}
})
}
}
func TestHandler_extractOrLookupJWK(t *testing.T) {
u := "https://ca.smallstep.com/acme/account"
type test struct {
db acme.DB
linker Linker
statusCode int
ctx context.Context
err *acme.Error
next func(w http.ResponseWriter, r *http.Request)
}
var tests = map[string]func(t *testing.T) test{
"ok/extract": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
kid, err := jwk.Thumbprint(crypto.SHA256)
assert.FatalError(t, err)
pub := jwk.Public()
pub.KeyID = base64.RawURLEncoding.EncodeToString(kid)
so := new(jose.SignerOptions)
so.WithHeader("jwk", pub)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
signed, err := signer.Sign([]byte("foo"))
assert.FatalError(t, err)
raw, err := signed.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
assert.FatalError(t, err)
acc := &acme.Account{Status: "valid"}
return test{
linker: NewLinker("dns", "acme"),
db: &acme.MockDB{
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
assert.Equals(t, kid, pub.KeyID)
return acc, nil
},
},
ctx: context.WithValue(context.Background(), jwsContextKey, parsedJWS),
statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
}
},
"ok/lookup": func(t *testing.T) test {
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
accID := "accID"
prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName)
so := new(jose.SignerOptions)
so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID))
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
jws, err := signer.Sign([]byte("baz"))
assert.FatalError(t, err)
raw, err := jws.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
assert.FatalError(t, err)
//acc := &acme.Account{Status: "valid", Key: jwk}
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
linker: NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, acc.ID)
return acc, nil
},
},
ctx: ctx,
statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(testBody)
},
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.extractOrLookupJWK(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
var ae acme.Error
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
assert.Equals(t, ae.Type, tc.err.Type)
assert.Equals(t, ae.Detail, tc.err.Detail)
assert.Equals(t, ae.Identifier, tc.err.Identifier)
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
} else {
assert.Equals(t, bytes.TrimSpace(body), testBody)
}
})
}
}