Add tests for extractOrLookupJWK middleware
This commit is contained in:
parent
3151255a25
commit
c7a9c13060
2 changed files with 191 additions and 0 deletions
|
@ -378,6 +378,9 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
||||||
|
|
||||||
// canExtractJWKFrom checks if the JWS has a JWK that can be extracted
|
// canExtractJWKFrom checks if the JWS has a JWK that can be extracted
|
||||||
func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
||||||
|
if jws == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
if len(jws.Signatures) == 0 {
|
if len(jws.Signatures) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -1473,3 +1473,191 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_canExtractJWKFrom(t *testing.T) {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
type args struct {
|
||||||
|
jws *jose.JSONWebSignature
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no-jws",
|
||||||
|
args: args{
|
||||||
|
jws: nil,
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-signatures",
|
||||||
|
args: args{
|
||||||
|
jws: &jose.JSONWebSignature{
|
||||||
|
Signatures: []jose.Signature{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no-jwk",
|
||||||
|
args: args{
|
||||||
|
jws: &jose.JSONWebSignature{
|
||||||
|
Signatures: []jose.Signature{
|
||||||
|
{
|
||||||
|
Protected: jose.Header{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
args: args{
|
||||||
|
jws: &jose.JSONWebSignature{
|
||||||
|
Signatures: []jose.Signature{
|
||||||
|
{
|
||||||
|
Protected: jose.Header{
|
||||||
|
JSONWebKey: jwk,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := canExtractJWKFrom(tt.args.jws); got != tt.want {
|
||||||
|
t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||||
|
u := "https://ca.smallstep.com/acme/account"
|
||||||
|
type test struct {
|
||||||
|
db acme.DB
|
||||||
|
linker Linker
|
||||||
|
statusCode int
|
||||||
|
ctx context.Context
|
||||||
|
err *acme.Error
|
||||||
|
next func(w http.ResponseWriter, r *http.Request)
|
||||||
|
}
|
||||||
|
var tests = map[string]func(t *testing.T) test{
|
||||||
|
"ok/extract": func(t *testing.T) test {
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
pub := jwk.Public()
|
||||||
|
pub.KeyID = base64.RawURLEncoding.EncodeToString(kid)
|
||||||
|
so := new(jose.SignerOptions)
|
||||||
|
so.WithHeader("jwk", pub)
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||||
|
Key: jwk.Key,
|
||||||
|
}, so)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
signed, err := signer.Sign([]byte("foo"))
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
raw, err := signed.CompactSerialize()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
acc := &acme.Account{Status: "valid"}
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("dns", "acme"),
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, kid, pub.KeyID)
|
||||||
|
return acc, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: context.WithValue(context.Background(), jwsContextKey, parsedJWS),
|
||||||
|
statusCode: 200,
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/lookup": func(t *testing.T) test {
|
||||||
|
prov := newProv()
|
||||||
|
provName := url.PathEscape(prov.GetName())
|
||||||
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
accID := "accID"
|
||||||
|
prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName)
|
||||||
|
so := new(jose.SignerOptions)
|
||||||
|
so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID))
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
|
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||||
|
Key: jwk.Key,
|
||||||
|
}, so)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
jws, err := signer.Sign([]byte("baz"))
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
raw, err := jws.CompactSerialize()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
//acc := &acme.Account{Status: "valid", Key: jwk}
|
||||||
|
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
|
||||||
|
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||||
|
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||||
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
|
return test{
|
||||||
|
linker: NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, accID, acc.ID)
|
||||||
|
return acc, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: 200,
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name, prep := range tests {
|
||||||
|
tc := prep(t)
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
h := &Handler{db: tc.db, linker: tc.linker}
|
||||||
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
|
req = req.WithContext(tc.ctx)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
h.extractOrLookupJWK(tc.next)(w, req)
|
||||||
|
res := w.Result()
|
||||||
|
|
||||||
|
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
res.Body.Close()
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||||
|
var ae acme.Error
|
||||||
|
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||||
|
|
||||||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||||||
|
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||||
|
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||||
|
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||||
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||||
|
} else {
|
||||||
|
assert.Equals(t, bytes.TrimSpace(body), testBody)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue