forked from TrueCloudLab/certificates
Add provisioner controller tests.
This commit is contained in:
parent
fd6a2eeb9c
commit
3c2ff33ca9
1 changed files with 391 additions and 0 deletions
391
authority/provisioner/controller_test.go
Normal file
391
authority/provisioner/controller_test.go
Normal file
|
@ -0,0 +1,391 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var trueValue = true
|
||||
|
||||
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
|
||||
t.Helper()
|
||||
c, err := NewClaimer(claims, global)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
func mustDuration(t *testing.T, s string) *Duration {
|
||||
t.Helper()
|
||||
d, err := NewDuration(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func TestNewController(t *testing.T) {
|
||||
type args struct {
|
||||
p Interface
|
||||
claims *Claims
|
||||
config Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Controller
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{&JWK{}, nil, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}}, &Controller{
|
||||
Interface: &JWK{},
|
||||
Audiences: &testAudiences,
|
||||
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||
}, false},
|
||||
{"ok with claims", args{&JWK{}, &Claims{
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}}, &Controller{
|
||||
Interface: &JWK{},
|
||||
Audiences: &testAudiences,
|
||||
Claimer: mustClaimer(t, &Claims{
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}, globalProvisionerClaims),
|
||||
}, false},
|
||||
{"fail claimer", args{&JWK{}, &Claims{
|
||||
MinTLSDur: mustDuration(t, "24h"),
|
||||
MaxTLSDur: mustDuration(t, "2h"),
|
||||
}, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewController() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestController_GetIdentity(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
type fields struct {
|
||||
Interface Interface
|
||||
IdentityFunc GetIdentityFunc
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
email string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *Identity
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{
|
||||
Usernames: []string{"jane", "jane@doe.org"},
|
||||
}, false},
|
||||
{"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
return &Identity{Usernames: []string{"jane"}}, nil
|
||||
}}, args{ctx, "jane@doe.org"}, &Identity{
|
||||
Usernames: []string{"jane"},
|
||||
}, false},
|
||||
{"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true},
|
||||
{"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
}}, args{ctx, "jane@doe.org"}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Controller{
|
||||
Interface: tt.fields.Interface,
|
||||
IdentityFunc: tt.fields.IdentityFunc,
|
||||
}
|
||||
got, err := c.GetIdentity(tt.args.ctx, tt.args.email)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestController_AuthorizeRenew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
type fields struct {
|
||||
Interface Interface
|
||||
Claimer *Claimer
|
||||
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, false},
|
||||
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, true},
|
||||
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now.Add(time.Hour),
|
||||
NotAfter: now.Add(2 * time.Hour),
|
||||
}}, true},
|
||||
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, true},
|
||||
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
return fmt.Errorf("an error")
|
||||
}}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Controller{
|
||||
Interface: tt.fields.Interface,
|
||||
Claimer: tt.fields.Claimer,
|
||||
AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
|
||||
}
|
||||
if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestController_AuthorizeSSHRenew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
type fields struct {
|
||||
Interface Interface
|
||||
Claimer *Claimer
|
||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
cert *ssh.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, false},
|
||||
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, true},
|
||||
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
|
||||
}}, true},
|
||||
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, true},
|
||||
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
return fmt.Errorf("an error")
|
||||
}}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Controller{
|
||||
Interface: tt.fields.Interface,
|
||||
Claimer: tt.fields.Claimer,
|
||||
AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
|
||||
}
|
||||
if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAuthorizeRenew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
p *Controller
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok renew after expiry", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, false},
|
||||
{"fail disabled", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, true},
|
||||
{"fail not yet valid", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now.Add(time.Hour),
|
||||
NotAfter: now.Add(2 * time.Hour),
|
||||
}}, true},
|
||||
{"fail expired", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
p *Controller
|
||||
cert *ssh.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok renew after expiry", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, false},
|
||||
{"fail disabled", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, true},
|
||||
{"fail not yet valid", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
|
||||
}}, true},
|
||||
{"fail expired", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue