Add tests for extractOrLookupJWK middleware
This commit is contained in:
parent
3151255a25
commit
c7a9c13060
2 changed files with 191 additions and 0 deletions
|
@ -378,6 +378,9 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
|
|||
|
||||
// canExtractJWKFrom checks if the JWS has a JWK that can be extracted
|
||||
func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
|
||||
if jws == nil {
|
||||
return false
|
||||
}
|
||||
if len(jws.Signatures) == 0 {
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -1473,3 +1473,191 @@ func TestHandler_validateJWS(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_canExtractJWKFrom(t *testing.T) {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
type args struct {
|
||||
jws *jose.JSONWebSignature
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no-jws",
|
||||
args: args{
|
||||
jws: nil,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no-signatures",
|
||||
args: args{
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{},
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no-jwk",
|
||||
args: args{
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{
|
||||
{
|
||||
Protected: jose.Header{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
jws: &jose.JSONWebSignature{
|
||||
Signatures: []jose.Signature{
|
||||
{
|
||||
Protected: jose.Header{
|
||||
JSONWebKey: jwk,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := canExtractJWKFrom(tt.args.jws); got != tt.want {
|
||||
t.Errorf("canExtractJWKFrom() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandler_extractOrLookupJWK(t *testing.T) {
|
||||
u := "https://ca.smallstep.com/acme/account"
|
||||
type test struct {
|
||||
db acme.DB
|
||||
linker Linker
|
||||
statusCode int
|
||||
ctx context.Context
|
||||
err *acme.Error
|
||||
next func(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
var tests = map[string]func(t *testing.T) test{
|
||||
"ok/extract": func(t *testing.T) test {
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
kid, err := jwk.Thumbprint(crypto.SHA256)
|
||||
assert.FatalError(t, err)
|
||||
pub := jwk.Public()
|
||||
pub.KeyID = base64.RawURLEncoding.EncodeToString(kid)
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithHeader("jwk", pub)
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||
Key: jwk.Key,
|
||||
}, so)
|
||||
assert.FatalError(t, err)
|
||||
signed, err := signer.Sign([]byte("foo"))
|
||||
assert.FatalError(t, err)
|
||||
raw, err := signed.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, err)
|
||||
acc := &acme.Account{Status: "valid"}
|
||||
return test{
|
||||
linker: NewLinker("dns", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
|
||||
assert.Equals(t, kid, pub.KeyID)
|
||||
return acc, nil
|
||||
},
|
||||
},
|
||||
ctx: context.WithValue(context.Background(), jwsContextKey, parsedJWS),
|
||||
statusCode: 200,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(testBody)
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/lookup": func(t *testing.T) test {
|
||||
prov := newProv()
|
||||
provName := url.PathEscape(prov.GetName())
|
||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
accID := "accID"
|
||||
prefix := fmt.Sprintf("%s/acme/%s/account/", baseURL, provName)
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithHeader("kid", fmt.Sprintf("%s%s", prefix, accID))
|
||||
signer, err := jose.NewSigner(jose.SigningKey{
|
||||
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
||||
Key: jwk.Key,
|
||||
}, so)
|
||||
assert.FatalError(t, err)
|
||||
jws, err := signer.Sign([]byte("baz"))
|
||||
assert.FatalError(t, err)
|
||||
raw, err := jws.CompactSerialize()
|
||||
assert.FatalError(t, err)
|
||||
parsedJWS, err := jose.ParseJWS(raw)
|
||||
assert.FatalError(t, err)
|
||||
assert.FatalError(t, err)
|
||||
//acc := &acme.Account{Status: "valid", Key: jwk}
|
||||
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
|
||||
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
|
||||
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
|
||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||
return test{
|
||||
linker: NewLinker("test.ca.smallstep.com", "acme"),
|
||||
db: &acme.MockDB{
|
||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||
assert.Equals(t, accID, acc.ID)
|
||||
return acc, nil
|
||||
},
|
||||
},
|
||||
ctx: ctx,
|
||||
statusCode: 200,
|
||||
next: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(testBody)
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
for name, prep := range tests {
|
||||
tc := prep(t)
|
||||
t.Run(name, func(t *testing.T) {
|
||||
h := &Handler{db: tc.db, linker: tc.linker}
|
||||
req := httptest.NewRequest("GET", u, nil)
|
||||
req = req.WithContext(tc.ctx)
|
||||
w := httptest.NewRecorder()
|
||||
h.extractOrLookupJWK(tc.next)(w, req)
|
||||
res := w.Result()
|
||||
|
||||
assert.Equals(t, res.StatusCode, tc.statusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
if res.StatusCode >= 400 && assert.NotNil(t, tc.err) {
|
||||
var ae acme.Error
|
||||
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &ae))
|
||||
|
||||
assert.Equals(t, ae.Type, tc.err.Type)
|
||||
assert.Equals(t, ae.Detail, tc.err.Detail)
|
||||
assert.Equals(t, ae.Identifier, tc.err.Identifier)
|
||||
assert.Equals(t, ae.Subproblems, tc.err.Subproblems)
|
||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/problem+json"})
|
||||
} else {
|
||||
assert.Equals(t, bytes.TrimSpace(body), testBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue