// Forked from TrueCloudLab/certificates (Go source, 391 lines, 13 KiB).

package provisioner
|
|
|
|
import (
|
|
"context"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// trueValue is an addressable true. Boolean claim fields such as
// DisableRenewal and AllowRenewAfterExpiry are *bool, so the test tables
// pass &trueValue.
var (
	trueValue = true
)
|
|
|
|
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
|
|
t.Helper()
|
|
c, err := NewClaimer(claims, global)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return c
|
|
}
|
|
func mustDuration(t *testing.T, s string) *Duration {
|
|
t.Helper()
|
|
d, err := NewDuration(s)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return d
|
|
}
|
|
|
|
func TestNewController(t *testing.T) {
|
|
type args struct {
|
|
p Interface
|
|
claims *Claims
|
|
config Config
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want *Controller
|
|
wantErr bool
|
|
}{
|
|
{"ok", args{&JWK{}, nil, Config{
|
|
Claims: globalProvisionerClaims,
|
|
Audiences: testAudiences,
|
|
}}, &Controller{
|
|
Interface: &JWK{},
|
|
Audiences: &testAudiences,
|
|
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
|
}, false},
|
|
{"ok with claims", args{&JWK{}, &Claims{
|
|
DisableRenewal: &defaultDisableRenewal,
|
|
}, Config{
|
|
Claims: globalProvisionerClaims,
|
|
Audiences: testAudiences,
|
|
}}, &Controller{
|
|
Interface: &JWK{},
|
|
Audiences: &testAudiences,
|
|
Claimer: mustClaimer(t, &Claims{
|
|
DisableRenewal: &defaultDisableRenewal,
|
|
}, globalProvisionerClaims),
|
|
}, false},
|
|
{"fail claimer", args{&JWK{}, &Claims{
|
|
MinTLSDur: mustDuration(t, "24h"),
|
|
MaxTLSDur: mustDuration(t, "2h"),
|
|
}, Config{
|
|
Claims: globalProvisionerClaims,
|
|
Audiences: testAudiences,
|
|
}}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("NewController() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestController_GetIdentity(t *testing.T) {
|
|
ctx := context.Background()
|
|
type fields struct {
|
|
Interface Interface
|
|
IdentityFunc GetIdentityFunc
|
|
}
|
|
type args struct {
|
|
ctx context.Context
|
|
email string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
want *Identity
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{
|
|
Usernames: []string{"jane", "jane@doe.org"},
|
|
}, false},
|
|
{"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
|
return &Identity{Usernames: []string{"jane"}}, nil
|
|
}}, args{ctx, "jane@doe.org"}, &Identity{
|
|
Usernames: []string{"jane"},
|
|
}, false},
|
|
{"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true},
|
|
{"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
|
return nil, fmt.Errorf("an error")
|
|
}}, args{ctx, "jane@doe.org"}, nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
c := &Controller{
|
|
Interface: tt.fields.Interface,
|
|
IdentityFunc: tt.fields.IdentityFunc,
|
|
}
|
|
got, err := c.GetIdentity(tt.args.ctx, tt.args.email)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestController_AuthorizeRenew(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().Truncate(time.Second)
|
|
type fields struct {
|
|
Interface Interface
|
|
Claimer *Claimer
|
|
AuthorizeRenewFunc AuthorizeRenewFunc
|
|
}
|
|
type args struct {
|
|
ctx context.Context
|
|
cert *x509.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
|
NotBefore: now,
|
|
NotAfter: now.Add(time.Hour),
|
|
}}, false},
|
|
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
|
return nil
|
|
}}, args{ctx, &x509.Certificate{
|
|
NotBefore: now,
|
|
NotAfter: now.Add(time.Hour),
|
|
}}, false},
|
|
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
|
return nil
|
|
}}, args{ctx, &x509.Certificate{
|
|
NotBefore: now,
|
|
NotAfter: now.Add(time.Hour),
|
|
}}, false},
|
|
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
|
NotBefore: now.Add(-time.Hour),
|
|
NotAfter: now.Add(-time.Minute),
|
|
}}, false},
|
|
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
|
NotBefore: now,
|
|
NotAfter: now.Add(time.Hour),
|
|
}}, true},
|
|
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
|
NotBefore: now.Add(time.Hour),
|
|
NotAfter: now.Add(2 * time.Hour),
|
|
}}, true},
|
|
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
|
NotBefore: now.Add(-time.Hour),
|
|
NotAfter: now.Add(-time.Minute),
|
|
}}, true},
|
|
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
|
return fmt.Errorf("an error")
|
|
}}, args{ctx, &x509.Certificate{
|
|
NotBefore: now,
|
|
NotAfter: now.Add(time.Hour),
|
|
}}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
c := &Controller{
|
|
Interface: tt.fields.Interface,
|
|
Claimer: tt.fields.Claimer,
|
|
AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
|
|
}
|
|
if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
|
|
t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestController_AuthorizeSSHRenew(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now()
|
|
type fields struct {
|
|
Interface Interface
|
|
Claimer *Claimer
|
|
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
|
}
|
|
type args struct {
|
|
ctx context.Context
|
|
cert *ssh.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
fields fields
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Unix()),
|
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
|
}}, false},
|
|
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
|
return nil
|
|
}}, args{ctx, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Unix()),
|
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
|
}}, false},
|
|
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
|
return nil
|
|
}}, args{ctx, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Unix()),
|
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
|
}}, false},
|
|
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
|
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
|
}}, false},
|
|
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Unix()),
|
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
|
}}, true},
|
|
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Add(time.Hour).Unix()),
|
|
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
|
|
}}, true},
|
|
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
|
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
|
}}, true},
|
|
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
|
return fmt.Errorf("an error")
|
|
}}, args{ctx, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Unix()),
|
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
|
}}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
c := &Controller{
|
|
Interface: tt.fields.Interface,
|
|
Claimer: tt.fields.Claimer,
|
|
AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
|
|
}
|
|
if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
|
|
t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultAuthorizeRenew(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now().Truncate(time.Second)
|
|
type args struct {
|
|
ctx context.Context
|
|
p *Controller
|
|
cert *x509.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{"ok", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
|
}, &x509.Certificate{
|
|
NotBefore: now,
|
|
NotAfter: now.Add(time.Hour),
|
|
}}, false},
|
|
{"ok renew after expiry", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
|
}, &x509.Certificate{
|
|
NotBefore: now.Add(-time.Hour),
|
|
NotAfter: now.Add(-time.Minute),
|
|
}}, false},
|
|
{"fail disabled", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
|
}, &x509.Certificate{
|
|
NotBefore: now,
|
|
NotAfter: now.Add(time.Hour),
|
|
}}, true},
|
|
{"fail not yet valid", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
|
}, &x509.Certificate{
|
|
NotBefore: now.Add(time.Hour),
|
|
NotAfter: now.Add(2 * time.Hour),
|
|
}}, true},
|
|
{"fail expired", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
|
}, &x509.Certificate{
|
|
NotBefore: now.Add(-time.Hour),
|
|
NotAfter: now.Add(-time.Minute),
|
|
}}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
|
|
t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
|
|
ctx := context.Background()
|
|
now := time.Now()
|
|
type args struct {
|
|
ctx context.Context
|
|
p *Controller
|
|
cert *ssh.Certificate
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
wantErr bool
|
|
}{
|
|
{"ok", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
|
}, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Unix()),
|
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
|
}}, false},
|
|
{"ok renew after expiry", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
|
}, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
|
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
|
}}, false},
|
|
{"fail disabled", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
|
}, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Unix()),
|
|
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
|
}}, true},
|
|
{"fail not yet valid", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
|
}, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Add(time.Hour).Unix()),
|
|
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
|
|
}}, true},
|
|
{"fail expired", args{ctx, &Controller{
|
|
Interface: &JWK{},
|
|
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
|
}, &ssh.Certificate{
|
|
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
|
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
|
}}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
|
|
t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|