2022-03-10 02:43:27 +00:00
package provisioner
import (
"context"
"crypto/x509"
"fmt"
"reflect"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
// trueValue is an addressable boolean true, so tests can take &trueValue when
// a claim field needs a *bool set to true.
var (
	trueValue = true
)
// mustClaimer builds a Claimer from the provisioner-level claims and the
// global claims via NewClaimer, aborting the calling test if NewClaimer
// returns an error.
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
	t.Helper()
	claimer, err := NewClaimer(claims, global)
	if err == nil {
		return claimer
	}
	t.Fatal(err)
	return nil // unreachable: t.Fatal stops the test goroutine
}
// mustDuration parses s with NewDuration and returns the result, aborting the
// calling test if parsing fails.
func mustDuration(t *testing.T, s string) *Duration {
	t.Helper()
	parsed, err := NewDuration(s)
	if err == nil {
		return parsed
	}
	t.Fatal(err)
	return nil // unreachable: t.Fatal stops the test goroutine
}
// TestNewController verifies three things about NewController. It returns a
// Controller holding the provisioner, a pointer to the configured audiences,
// and a Claimer built from the provisioner and global claims. Optional
// provisioner claims are honored. Claims whose minimum TLS duration exceeds
// the maximum are reported as an error.
func TestNewController(t *testing.T) {
	type args struct {
		p      Interface
		claims *Claims
		config Config
	}
	// All cases share the same configuration contents; build a fresh value each time.
	newConfig := func() Config {
		return Config{
			Claims:    globalProvisionerClaims,
			Audiences: testAudiences,
		}
	}
	cases := []struct {
		name    string
		args    args
		want    *Controller
		wantErr bool
	}{
		{
			name: "ok",
			args: args{p: &JWK{}, claims: nil, config: newConfig()},
			want: &Controller{
				Interface: &JWK{},
				Audiences: &testAudiences,
				Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
			},
		},
		{
			name: "ok with claims",
			args: args{p: &JWK{}, claims: &Claims{DisableRenewal: &defaultDisableRenewal}, config: newConfig()},
			want: &Controller{
				Interface: &JWK{},
				Audiences: &testAudiences,
				Claimer:   mustClaimer(t, &Claims{DisableRenewal: &defaultDisableRenewal}, globalProvisionerClaims),
			},
		},
		{
			name: "fail claimer",
			args: args{p: &JWK{}, claims: &Claims{
				MinTLSDur: mustDuration(t, "24h"),
				MaxTLSDur: mustDuration(t, "2h"),
			}, config: newConfig()},
			want:    nil,
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := NewController(tc.args.p, tc.args.claims, tc.args.config)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("NewController() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("NewController() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestController_GetIdentity covers two paths in Controller.GetIdentity.
// Without an IdentityFunc it falls back to the provisioner's default identity:
// the OIDC case expects the local part plus the full email, and the JWK case
// expects an error. With an IdentityFunc, the custom function's identity or
// error is returned.
func TestController_GetIdentity(t *testing.T) {
	ctx := context.Background()
	const email = "jane@doe.org"
	type fields struct {
		Interface    Interface
		IdentityFunc GetIdentityFunc
	}
	type args struct {
		ctx   context.Context
		email string
	}
	// Custom identity functions used by the "custom" cases.
	fixedIdentity := func(context.Context, Interface, string) (*Identity, error) {
		return &Identity{Usernames: []string{"jane"}}, nil
	}
	failingIdentity := func(context.Context, Interface, string) (*Identity, error) {
		return nil, fmt.Errorf("an error")
	}
	cases := []struct {
		name    string
		fields  fields
		args    args
		want    *Identity
		wantErr bool
	}{
		{
			name:   "ok",
			fields: fields{Interface: &OIDC{}},
			args:   args{ctx: ctx, email: email},
			want:   &Identity{Usernames: []string{"jane", email}},
		},
		{
			name:   "ok custom",
			fields: fields{Interface: &OIDC{}, IdentityFunc: fixedIdentity},
			args:   args{ctx: ctx, email: email},
			want:   &Identity{Usernames: []string{"jane"}},
		},
		{
			name:    "fail provisioner",
			fields:  fields{Interface: &JWK{}},
			args:    args{ctx: ctx, email: email},
			wantErr: true,
		},
		{
			name:    "fail custom",
			fields:  fields{Interface: &OIDC{}, IdentityFunc: failingIdentity},
			args:    args{ctx: ctx, email: email},
			wantErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctrl := &Controller{
				Interface:    tc.fields.Interface,
				IdentityFunc: tc.fields.IdentityFunc,
			}
			got, err := ctrl.GetIdentity(tc.args.ctx, tc.args.email)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("Controller.GetIdentity() = %v, want %v", got, tc.want)
			}
		})
	}
}
func TestController_AuthorizeRenew ( t * testing . T ) {
ctx := context . Background ( )
2022-03-10 18:46:28 +00:00
now := time . Now ( ) . Truncate ( time . Second )
2022-03-10 02:43:27 +00:00
type fields struct {
Interface Interface
Claimer * Claimer
AuthorizeRenewFunc AuthorizeRenewFunc
}
type args struct {
ctx context . Context
cert * x509 . Certificate
}
tests := [ ] struct {
name string
fields fields
args args
wantErr bool
} {
{ "ok" , fields { & JWK { } , mustClaimer ( t , nil , globalProvisionerClaims ) , nil } , args { ctx , & x509 . Certificate {
NotBefore : now ,
NotAfter : now . Add ( time . Hour ) ,
} } , false } ,
{ "ok custom" , fields { & JWK { } , mustClaimer ( t , nil , globalProvisionerClaims ) , func ( ctx context . Context , p * Controller , cert * x509 . Certificate ) error {
return nil
} } , args { ctx , & x509 . Certificate {
NotBefore : now ,
NotAfter : now . Add ( time . Hour ) ,
} } , false } ,
{ "ok custom disabled" , fields { & JWK { } , mustClaimer ( t , & Claims { EnableRenewAfterExpiry : & trueValue } , globalProvisionerClaims ) , func ( ctx context . Context , p * Controller , cert * x509 . Certificate ) error {
return nil
} } , args { ctx , & x509 . Certificate {
NotBefore : now ,
NotAfter : now . Add ( time . Hour ) ,
} } , false } ,
{ "ok renew after expiry" , fields { & JWK { } , mustClaimer ( t , & Claims { EnableRenewAfterExpiry : & trueValue } , globalProvisionerClaims ) , nil } , args { ctx , & x509 . Certificate {
NotBefore : now . Add ( - time . Hour ) ,
NotAfter : now . Add ( - time . Minute ) ,
} } , false } ,
{ "fail disabled" , fields { & JWK { } , mustClaimer ( t , & Claims { DisableRenewal : & trueValue } , globalProvisionerClaims ) , nil } , args { ctx , & x509 . Certificate {
NotBefore : now ,
NotAfter : now . Add ( time . Hour ) ,
} } , true } ,
{ "fail not yet valid" , fields { & JWK { } , mustClaimer ( t , nil , globalProvisionerClaims ) , nil } , args { ctx , & x509 . Certificate {
NotBefore : now . Add ( time . Hour ) ,
NotAfter : now . Add ( 2 * time . Hour ) ,
} } , true } ,
{ "fail expired" , fields { & JWK { } , mustClaimer ( t , nil , globalProvisionerClaims ) , nil } , args { ctx , & x509 . Certificate {
NotBefore : now . Add ( - time . Hour ) ,
NotAfter : now . Add ( - time . Minute ) ,
} } , true } ,
{ "fail custom" , fields { & JWK { } , mustClaimer ( t , nil , globalProvisionerClaims ) , func ( ctx context . Context , p * Controller , cert * x509 . Certificate ) error {
return fmt . Errorf ( "an error" )
} } , args { ctx , & x509 . Certificate {
NotBefore : now ,
NotAfter : now . Add ( time . Hour ) ,
} } , true } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
c := & Controller {
Interface : tt . fields . Interface ,
Claimer : tt . fields . Claimer ,
AuthorizeRenewFunc : tt . fields . AuthorizeRenewFunc ,
}
if err := c . AuthorizeRenew ( tt . args . ctx , tt . args . cert ) ; ( err != nil ) != tt . wantErr {
t . Errorf ( "Controller.AuthorizeRenew() error = %v, wantErr %v" , err , tt . wantErr )
}
} )
}
}
// TestController_AuthorizeSSHRenew covers Controller.AuthorizeSSHRenew with
// and without a custom AuthorizeSSHRenewFunc.
//
// Without a custom func, the cases expect the default policy. Renewal must
// not be disabled, and the certificate must be inside its
// ValidAfter/ValidBefore window. Expired certificates are accepted only with
// EnableRenewAfterExpiry.
//
// With a custom func, the cases expect that func's result to be returned,
// even when the claims disable renewal.
func TestController_AuthorizeSSHRenew(t *testing.T) {
	ctx := context.Background()
	// SSH validity bounds are stored as whole Unix seconds below, so no
	// explicit truncation of now is needed.
	now := time.Now()
	type fields struct {
		Interface             Interface
		Claimer               *Claimer
		AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
	}
	type args struct {
		ctx  context.Context
		cert *ssh.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, false},
		{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
			return nil
		}}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, false},
		// Renewal is disabled by the claims, but the custom func's nil result
		// is expected to be returned anyway.
		{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
			return nil
		}}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, false},
		{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(-time.Hour).Unix()),
			ValidBefore: uint64(now.Add(-time.Minute).Unix()),
		}}, false},
		{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, true},
		{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(time.Hour).Unix()),
			ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
		}}, true},
		{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(-time.Hour).Unix()),
			ValidBefore: uint64(now.Add(-time.Minute).Unix()),
		}}, true},
		{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
			return fmt.Errorf("an error")
		}}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Controller{
				Interface:             tt.fields.Interface,
				Claimer:               tt.fields.Claimer,
				AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
			}
			if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
				t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}
func TestDefaultAuthorizeRenew ( t * testing . T ) {
ctx := context . Background ( )
2022-03-10 18:46:28 +00:00
now := time . Now ( ) . Truncate ( time . Second )
2022-03-10 02:43:27 +00:00
type args struct {
ctx context . Context
p * Controller
cert * x509 . Certificate
}
tests := [ ] struct {
name string
args args
wantErr bool
} {
{ "ok" , args { ctx , & Controller {
Interface : & JWK { } ,
Claimer : mustClaimer ( t , nil , globalProvisionerClaims ) ,
} , & x509 . Certificate {
NotBefore : now ,
NotAfter : now . Add ( time . Hour ) ,
} } , false } ,
{ "ok renew after expiry" , args { ctx , & Controller {
Interface : & JWK { } ,
Claimer : mustClaimer ( t , & Claims { EnableRenewAfterExpiry : & trueValue } , globalProvisionerClaims ) ,
} , & x509 . Certificate {
NotBefore : now . Add ( - time . Hour ) ,
NotAfter : now . Add ( - time . Minute ) ,
} } , false } ,
{ "fail disabled" , args { ctx , & Controller {
Interface : & JWK { } ,
Claimer : mustClaimer ( t , & Claims { DisableRenewal : & trueValue } , globalProvisionerClaims ) ,
} , & x509 . Certificate {
NotBefore : now ,
NotAfter : now . Add ( time . Hour ) ,
} } , true } ,
{ "fail not yet valid" , args { ctx , & Controller {
Interface : & JWK { } ,
Claimer : mustClaimer ( t , & Claims { DisableRenewal : & trueValue } , globalProvisionerClaims ) ,
} , & x509 . Certificate {
NotBefore : now . Add ( time . Hour ) ,
NotAfter : now . Add ( 2 * time . Hour ) ,
} } , true } ,
{ "fail expired" , args { ctx , & Controller {
Interface : & JWK { } ,
Claimer : mustClaimer ( t , & Claims { DisableRenewal : & trueValue } , globalProvisionerClaims ) ,
} , & x509 . Certificate {
NotBefore : now . Add ( - time . Hour ) ,
NotAfter : now . Add ( - time . Minute ) ,
} } , true } ,
}
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
if err := DefaultAuthorizeRenew ( tt . args . ctx , tt . args . p , tt . args . cert ) ; ( err != nil ) != tt . wantErr {
t . Errorf ( "DefaultAuthorizeRenew() error = %v, wantErr %v" , err , tt . wantErr )
}
} )
}
}
// TestDefaultAuthorizeSSHRenew checks the built-in SSH renewal policy.
// Renewal must not be disabled by the claims, and the certificate's
// ValidAfter must already have passed. A certificate past ValidBefore is
// accepted only when EnableRenewAfterExpiry is set.
//
// Each failure case enables exactly one failing condition so the error is
// attributable to the check the case is named after.
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
	ctx := context.Background()
	// SSH validity bounds are stored as whole Unix seconds below, so no
	// explicit truncation of now is needed.
	now := time.Now()
	type args struct {
		ctx  context.Context
		p    *Controller
		cert *ssh.Certificate
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{"ok", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, false},
		{"ok renew after expiry", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, &Claims{EnableRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(-time.Hour).Unix()),
			ValidBefore: uint64(now.Add(-time.Minute).Unix()),
		}}, false},
		{"fail disabled", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, true},
		// Default claims keep renewal enabled, so only the ValidAfter check can
		// fail here.
		{"fail not yet valid", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(time.Hour).Unix()),
			ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
		}}, true},
		// Default claims keep renewal enabled and do not allow renewal after
		// expiry, so only the ValidBefore check can fail here.
		{"fail expired", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(-time.Hour).Unix()),
			ValidBefore: uint64(now.Add(-time.Minute).Unix()),
		}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
				t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}