From 3c2ff33ca90cb97fd861b9a479eb851e4feffab0 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Wed, 9 Mar 2022 18:43:27 -0800 Subject: [PATCH] Add provisioner controller tests. --- authority/provisioner/controller_test.go | 391 +++++++++++++++++++++++ 1 file changed, 391 insertions(+) create mode 100644 authority/provisioner/controller_test.go diff --git a/authority/provisioner/controller_test.go b/authority/provisioner/controller_test.go new file mode 100644 index 00000000..68f7055c --- /dev/null +++ b/authority/provisioner/controller_test.go @@ -0,0 +1,391 @@ +package provisioner + +import ( + "context" + "crypto/x509" + "fmt" + "reflect" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +var trueValue = true + +func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer { + t.Helper() + c, err := NewClaimer(claims, global) + if err != nil { + t.Fatal(err) + } + return c +} +func mustDuration(t *testing.T, s string) *Duration { + t.Helper() + d, err := NewDuration(s) + if err != nil { + t.Fatal(err) + } + return d +} + +func TestNewController(t *testing.T) { + type args struct { + p Interface + claims *Claims + config Config + } + tests := []struct { + name string + args args + want *Controller + wantErr bool + }{ + {"ok", args{&JWK{}, nil, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, false}, + {"ok with claims", args{&JWK{}, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, &Controller{ + Interface: &JWK{}, + Audiences: &testAudiences, + Claimer: mustClaimer(t, &Claims{ + DisableRenewal: &defaultDisableRenewal, + }, globalProvisionerClaims), + }, false}, + {"fail claimer", args{&JWK{}, &Claims{ + MinTLSDur: mustDuration(t, "24h"), + MaxTLSDur: mustDuration(t, "2h"), + }, Config{ + Claims: globalProvisionerClaims, + Audiences: testAudiences, + }}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewController(tt.args.p, tt.args.claims, tt.args.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewController() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestController_GetIdentity(t *testing.T) { + ctx := context.Background() + type fields struct { + Interface Interface + IdentityFunc GetIdentityFunc + } + type args struct { + ctx context.Context + email string + } + tests := []struct { + name string + fields fields + args args + want *Identity + wantErr bool + }{ + {"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{ + Usernames: []string{"jane", "jane@doe.org"}, + }, false}, + {"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { + return &Identity{Usernames: []string{"jane"}}, nil + }}, args{ctx, "jane@doe.org"}, &Identity{ + Usernames: []string{"jane"}, + }, false}, + {"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true}, + {"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) { + return nil, fmt.Errorf("an error") + }}, args{ctx, "jane@doe.org"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + IdentityFunc: tt.fields.IdentityFunc, + } + got, err := c.GetIdentity(tt.args.ctx, tt.args.email) + if (err != nil) != tt.wantErr { + t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestController_AuthorizeRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type fields struct { + Interface Interface + Claimer *Claimer + AuthorizeRenewFunc AuthorizeRenewFunc + } + type args struct { + ctx context.Context + cert *x509.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return nil + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return nil + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, false}, + {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(time.Hour), + NotAfter: now.Add(2 * time.Hour), + }}, true}, + {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, true}, + {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { + return fmt.Errorf("an error") + }}, args{ctx, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + Claimer: tt.fields.Claimer, + AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc, + } + if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestController_AuthorizeSSHRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type fields struct { + Interface Interface + Claimer *Claimer + AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc + } + type args struct { + ctx context.Context + cert *ssh.Certificate + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return nil + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return nil + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, false}, + {"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + {"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(time.Hour).Unix()), + ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), + }}, true}, + {"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, true}, + {"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { + return fmt.Errorf("an error") + }}, args{ctx, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Controller{ + Interface: tt.fields.Interface, + Claimer: tt.fields.Claimer, + AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc, + } + if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultAuthorizeRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type args struct { + ctx context.Context + p *Controller + cert *x509.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, false}, + {"ok renew after expiry", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, false}, + {"fail disabled", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now, + NotAfter: now.Add(time.Hour), + }}, true}, + {"fail not yet valid", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(time.Hour), + NotAfter: now.Add(2 * time.Hour), + }}, true}, + {"fail expired", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &x509.Certificate{ + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(-time.Minute), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultAuthorizeSSHRenew(t *testing.T) { + ctx := context.Background() + now := time.Now() + type args struct { + ctx context.Context + p *Controller + cert *ssh.Certificate + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, nil, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, false}, + {"ok renew after expiry", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, false}, + {"fail disabled", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Unix()), + ValidBefore: uint64(now.Add(time.Hour).Unix()), + }}, true}, + {"fail not yet valid", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(time.Hour).Unix()), + ValidBefore: uint64(now.Add(2 * time.Hour).Unix()), + }}, true}, + {"fail expired", args{ctx, &Controller{ + Interface: &JWK{}, + Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), + }, &ssh.Certificate{ + ValidAfter: uint64(now.Add(-time.Hour).Unix()), + ValidBefore: uint64(now.Add(-time.Minute).Unix()), + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr { + t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}