From 6264e8495cbb5ac3deeab8f4469682d8babf00f9 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Sun, 24 Apr 2022 16:29:31 +0200 Subject: [PATCH] Improve policy error handling code coverage --- authority/administrator/collection.go | 4 +- authority/policy.go | 27 +- authority/policy_test.go | 687 ++++++++++++++++++++++++++ 3 files changed, 697 insertions(+), 21 deletions(-) diff --git a/authority/administrator/collection.go b/authority/administrator/collection.go index 300c3e4f..f40e7417 100644 --- a/authority/administrator/collection.go +++ b/authority/administrator/collection.go @@ -59,12 +59,12 @@ func newSubProv(subject, prov string) subProv { return subProv{subject, prov} } -// LoadBySubProv a admin by the subject and provisioner name. +// LoadBySubProv loads an admin by subject and provisioner name. func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) { return loadAdmin(c.bySubProv, newSubProv(sub, provName)) } -// LoadByProvisioner a admin by the subject and provisioner name. +// LoadByProvisioner loads admins by provisioner name. func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) { val, ok := c.byProv.Load(provName) if !ok { diff --git a/authority/policy.go b/authority/policy.go index 9bcbd044..b5ee7949 100644 --- a/authority/policy.go +++ b/authority/policy.go @@ -39,7 +39,10 @@ func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, e p, err := a.adminDB.GetAuthorityPolicy(ctx) if err != nil { - return nil, err + return nil, &PolicyError{ + Typ: InternalFailure, + Err: err, + } } return p, nil @@ -50,10 +53,7 @@ func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Adm defer a.adminMutex.Unlock() if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil { - return nil, &PolicyError{ - Typ: AdminLockOut, - Err: err, - } + return nil, err } if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil { @@ -91,7 +91,7 @@ func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Adm if err := a.reloadPolicyEngines(ctx); err != nil { return nil, &PolicyError{ Typ: ReloadFailure, - Err: fmt.Errorf("error reloading policy engines when updating authority policy %w", err), + Err: fmt.Errorf("error reloading policy engines when updating authority policy: %w", err), } } @@ -145,14 +145,8 @@ func (a *Authority) checkProvisionerPolicy(ctx context.Context, currentAdmin *li return nil } - // get all admins for the provisioner - allProvisionerAdmins, ok := a.admins.LoadByProvisioner(provName) - if !ok { - return &PolicyError{ - Typ: InternalFailure, - Err: errors.New("error retrieving admins by provisioner"), - } - } + // get all admins for the provisioner; ignoring case in which they're not found + allProvisionerAdmins, _ := a.admins.LoadByProvisioner(provName) return a.checkPolicy(ctx, currentAdmin, allProvisionerAdmins, p) } @@ -222,11 +216,6 @@ func (a *Authority) reloadPolicyEngines(ctx context.Context) error { return nil } - // // temporarily only support the admin nosql DB - // if _, ok := a.adminDB.(*adminDBNosql.DB); !ok { - // return nil - // } - linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx) if err != nil { var ae *admin.Error diff --git a/authority/policy_test.go b/authority/policy_test.go index bc121a79..410c3ed3 100644 --- a/authority/policy_test.go +++ b/authority/policy_test.go @@ -3,6 +3,7 @@ package authority import ( "context" "errors" + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -11,8 +12,11 @@ import ( "go.step.sm/linkedca" "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/policy" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/db" ) func TestAuthority_checkPolicy(t *testing.T) { @@ -871,3 +875,686 @@ func TestAuthority_reloadPolicyEngines(t *testing.T) { }) } } + +func TestAuthority_checkAuthorityPolicy(t *testing.T) { + type fields struct { + provisioners *provisioner.Collection + admins *administrator.Collection + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + currentAdmin *linkedca.Admin + provName string + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "no policy", + fields: fields{}, + args: args{ + currentAdmin: nil, + provName: "prov", + p: nil, + }, + wantErr: false, + }, + { + name: "fail/adminDB.GetAdmins-error", + fields: fields{ + admins: administrator.NewCollection(nil), + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ok", + fields: fields{ + admins: administrator.NewCollection(nil), + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + provisioners: tt.fields.provisioners, + admins: tt.fields.admins, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + if err := a.checkAuthorityPolicy(tt.args.ctx, tt.args.currentAdmin, tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("Authority.checkProvisionerPolicy() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAuthority_checkProvisionerPolicy(t *testing.T) { + type fields struct { + provisioners *provisioner.Collection + admins *administrator.Collection + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + currentAdmin *linkedca.Admin + provName string + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "no policy", + fields: fields{}, + args: args{ + currentAdmin: nil, + provName: "prov", + p: nil, + }, + wantErr: false, + }, + { + name: "ok", + fields: fields{ + admins: administrator.NewCollection(nil), + }, + args: args{ + currentAdmin: &linkedca.Admin{Subject: "step"}, + provName: "prov", + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + provisioners: tt.fields.provisioners, + admins: tt.fields.admins, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + if err := a.checkProvisionerPolicy(tt.args.ctx, tt.args.currentAdmin, tt.args.provName, tt.args.p); (err != nil) != tt.wantErr { + t.Errorf("Authority.checkProvisionerPolicy() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestAuthority_RemoveAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + wantErr *PolicyError + }{ + { + name: "fail/adminDB.DeleteAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return nil + }, + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when deleting authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockDeleteAuthorityPolicy: func(ctx context.Context) error { + return nil + }, + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, nil + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + err := a.RemoveAuthorityPolicy(tt.args.ctx) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + }) + } +} + +func TestAuthority_GetAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/adminDB.GetAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{}, nil + }, + }, + }, + want: &linkedca.Policy{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.GetAuthorityPolicy(tt.args.ctx) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.GetAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_CreateAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + adm *linkedca.Admin + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/a.checkAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("error retrieving admins: force"), + }, + }, + { + name: "fail/adminDB.CreateAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + MockCreateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when creating authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, nil + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.CreateAuthorityPolicy(tt.args.ctx, tt.args.adm, tt.args.p) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.CreateAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthority_UpdateAuthorityPolicy(t *testing.T) { + type fields struct { + config *config.Config + db db.AuthDB + adminDB admin.DB + } + type args struct { + ctx context.Context + adm *linkedca.Admin + p *linkedca.Policy + } + tests := []struct { + name string + fields fields + args args + want *linkedca.Policy + wantErr *PolicyError + }{ + { + name: "fail/a.checkAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return nil, errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: InternalFailure, + Err: errors.New("error retrieving admins: force"), + }, + }, + { + name: "fail/adminDB.UpdateAuthorityPolicy", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return errors.New("force") + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: StoreFailure, + Err: errors.New("force"), + }, + }, + { + name: "fail/a.reloadPolicyEngines", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return nil, errors.New("force") + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + wantErr: &PolicyError{ + Typ: ReloadFailure, + Err: errors.New("error reloading policy engines when updating authority policy: error getting policy to (re)load policy engines: force"), + }, + }, + { + name: "ok", + fields: fields{ + config: &config.Config{ + AuthorityConfig: &config.AuthConfig{ + EnableAdmin: true, + }, + }, + adminDB: &admin.MockDB{ + MockGetAuthorityPolicy: func(ctx context.Context) (*linkedca.Policy, error) { + return &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, nil + }, + MockUpdateAuthorityPolicy: func(ctx context.Context, policy *linkedca.Policy) error { + return nil + }, + MockGetAdmins: func(ctx context.Context) ([]*linkedca.Admin, error) { + return []*linkedca.Admin{}, nil + }, + }, + }, + args: args{ + ctx: context.Background(), + adm: &linkedca.Admin{Subject: "step"}, + p: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + want: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"step", "otherAdmin"}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &Authority{ + config: tt.fields.config, + db: tt.fields.db, + adminDB: tt.fields.adminDB, + } + got, err := a.UpdateAuthorityPolicy(tt.args.ctx, tt.args.adm, tt.args.p) + if err != nil { + pe, ok := err.(*PolicyError) + assert.True(t, ok) + assert.Equal(t, tt.wantErr.Typ, pe.Typ) + assert.Equal(t, tt.wantErr.Err.Error(), pe.Err.Error()) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Authority.UpdateAuthorityPolicy() = %v, want %v", got, tt.want) + } + }) + } +}