certificates/authority/provisioner/k8sSA_test.go
2019-10-29 17:42:50 -07:00

264 lines
6.9 KiB
Go

package provisioner
import (
"context"
"crypto/x509"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/cli/jose"
)
// TestK8sSA_Getters checks the accessor methods of a K8sSA provisioner built
// by generateK8sSA(nil): GetID, GetName, GetType and GetEncryptedKey.
func TestK8sSA_Getters(t *testing.T) {
	p, err := generateK8sSA(nil)
	assert.FatalError(t, err)

	// The expected ID is the "k8ssa/" prefix followed by the provisioner name.
	id := "k8ssa/" + p.Name
	if got := p.GetID(); got != id {
		t.Errorf("K8sSA.GetID() = %v, want %v", got, id)
	}
	if got := p.GetName(); got != p.Name {
		t.Errorf("K8sSA.GetName() = %v, want %v", got, p.Name)
	}
	if got := p.GetType(); got != TypeK8sSA {
		t.Errorf("K8sSA.GetType() = %v, want %v", got, TypeK8sSA)
	}

	// This test expects GetEncryptedKey to report "no key": two empty strings
	// and a false flag.
	kid, key, ok := p.GetEncryptedKey()
	if kid != "" || key != "" || ok {
		t.Errorf("K8sSA.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
			kid, key, ok, "", "", false)
	}
}
// TestK8sSA_authorizeToken exercises K8sSA.authorizeToken with a table of
// lazily constructed cases. Each case yields a provisioner, a token and the
// error prefix expected from authorizeToken (nil when the call must succeed).
func TestK8sSA_authorizeToken(t *testing.T) {
	type testCase struct {
		p     *K8sSA
		token string
		err   error
	}

	cases := map[string]func(*testing.T) testCase{
		// A string that is not a token at all.
		"fail/bad-token": func(t *testing.T) testCase {
			prov, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return testCase{
				p:     prov,
				token: "foo",
				err:   errors.New("error parsing token"),
			}
		},
		// A token from generateToken signed with a freshly generated key,
		// checked against a provisioner built from generateK8sSA(nil).
		"fail/error-validating-token": func(t *testing.T) testCase {
			signer, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			prov, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			signed, err := generateToken("", prov.Name, testAudiences.Sign[0], "",
				[]string{"test.smallstep.com"}, time.Now(), signer)
			assert.FatalError(t, err)
			return testCase{
				p:     prov,
				token: signed,
				err:   errors.New("error validating token and extracting claims"),
			}
		},
		// A K8sSA token whose issuer claim has been overwritten.
		"fail/invalid-issuer": func(t *testing.T) testCase {
			signer, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			prov, err := generateK8sSA(signer.Public().Key)
			assert.FatalError(t, err)
			payload := getK8sSAPayload()
			payload.Claims.Issuer = "invalid"
			signed, err := generateK8sSAToken(signer, payload)
			assert.FatalError(t, err)
			return testCase{
				p:     prov,
				token: signed,
				err:   errors.New("invalid token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
			}
		},
		// Provisioner and token are built from the same generated key.
		"ok": func(t *testing.T) testCase {
			signer, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			prov, err := generateK8sSA(signer.Public().Key)
			assert.FatalError(t, err)
			signed, err := generateK8sSAToken(signer, nil)
			assert.FatalError(t, err)
			return testCase{
				p:     prov,
				token: signed,
			}
		},
	}

	for name, build := range cases {
		t.Run(name, func(t *testing.T) {
			tc := build(t)
			claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign)
			if err == nil {
				// Success is only acceptable when no error was expected.
				if assert.Nil(t, tc.err) {
					assert.NotNil(t, claims)
				}
				return
			}
			// Failure must have been expected and must start with the
			// expected message.
			if assert.NotNil(t, tc.err) {
				assert.HasPrefix(t, err.Error(), tc.err.Error())
			}
		})
	}
}
// TestK8sSA_AuthorizeSign exercises K8sSA.AuthorizeSign with a table of lazily
// constructed cases. Each case yields a provisioner, a token, the context
// passed to AuthorizeSign and the error prefix expected (nil when the call
// must succeed). On success the returned sign options are inspected one by one.
func TestK8sSA_AuthorizeSign(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		ctx   context.Context
		err   error
	}
	tests := map[string]func(*testing.T) test{
		"fail/invalid-token": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p: p,
				// Always hand AuthorizeSign a real context: the context
				// package contract forbids passing a nil Context.
				ctx:   NewContextWithMethod(context.Background(), SignMethod),
				token: "foo",
				err:   errors.New("error parsing token"),
			}
		},
		"fail/ssh-unimplemented": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				ctx:   NewContextWithMethod(context.Background(), SignSSHMethod),
				token: tok,
				err:   errors.New("ssh certificates not enabled for k8s ServiceAccount provisioners"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				ctx:   NewContextWithMethod(context.Background(), SignMethod),
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			opts, err := tc.p.AuthorizeSign(tc.ctx, tc.token)
			if err != nil {
				// A failure must have been expected and must start with the
				// expected message.
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
				return
			}
			if !assert.Nil(t, tc.err) || !assert.NotNil(t, opts) {
				return
			}

			// Check every returned option and count them; any type not
			// listed here aborts the test.
			tot := 0
			for _, o := range opts {
				switch v := o.(type) {
				case *provisionerExtensionOption:
					assert.Equals(t, v.Type, int(TypeK8sSA))
					assert.Equals(t, v.Name, tc.p.GetName())
					assert.Equals(t, v.CredentialID, "")
					assert.Len(t, 0, v.KeyValuePairs)
				case profileDefaultDuration:
					assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
				case defaultPublicKeyValidator:
					// Nothing to inspect; its presence is enough.
				case *validityValidator:
					assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
					assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
				default:
					assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
				}
				tot++
			}
			// Exactly four sign options are expected.
			assert.Equals(t, tot, 4)
		})
	}
}
// TestK8sSA_AuthorizeRevoke exercises K8sSA.AuthorizeRevoke with a table of
// lazily constructed cases. Each case yields a provisioner, a token and the
// error prefix expected from AuthorizeRevoke (nil when the call must succeed).
func TestK8sSA_AuthorizeRevoke(t *testing.T) {
	type testCase struct {
		p     *K8sSA
		token string
		err   error
	}

	cases := map[string]func(*testing.T) testCase{
		// A string that is not a token at all.
		"fail/invalid-token": func(t *testing.T) testCase {
			prov, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return testCase{
				p:     prov,
				token: "foo",
				err:   errors.New("error parsing token"),
			}
		},
		// Provisioner and token are built from the same generated key.
		"ok": func(t *testing.T) testCase {
			signer, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			prov, err := generateK8sSA(signer.Public().Key)
			assert.FatalError(t, err)
			signed, err := generateK8sSAToken(signer, nil)
			assert.FatalError(t, err)
			return testCase{
				p:     prov,
				token: signed,
			}
		},
	}

	for name, build := range cases {
		t.Run(name, func(t *testing.T) {
			tc := build(t)
			err := tc.p.AuthorizeRevoke(tc.token)
			if err == nil {
				// Success is only acceptable when no error was expected.
				assert.Nil(t, tc.err)
				return
			}
			// Failure must have been expected and must start with the
			// expected message.
			if assert.NotNil(t, tc.err) {
				assert.HasPrefix(t, err.Error(), tc.err.Error())
			}
		})
	}
}
// TestK8sSA_AuthorizeRenewal checks K8sSA.AuthorizeRenewal against two
// provisioners: one left as generateK8sSA(nil) builds it, which must allow
// renewal, and one whose claims set DisableRenewal, which must reject it.
func TestK8sSA_AuthorizeRenewal(t *testing.T) {
	p1, err := generateK8sSA(nil)
	assert.FatalError(t, err)
	p2, err := generateK8sSA(nil)
	assert.FatalError(t, err)

	// Disable renewal on p2 and rebuild its claimer so the new claims apply.
	disable := true
	p2.Claims = &Claims{DisableRenewal: &disable}
	p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
	assert.FatalError(t, err)

	type args struct {
		cert *x509.Certificate
	}
	tests := []struct {
		name    string
		prov    *K8sSA
		args    args
		wantErr bool
	}{
		{name: "ok", prov: p1, args: args{cert: nil}, wantErr: false},
		{name: "fail", prov: p2, args: args{cert: nil}, wantErr: true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.prov.AuthorizeRenewal(tt.args.cert); (err != nil) != tt.wantErr {
				// Name the type under test (K8sSA) in the failure message.
				t.Errorf("K8sSA.AuthorizeRenewal() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}