From 99702d36484eadd855b2aaf46e1c9945c84e985a Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 18 Apr 2022 21:14:30 +0200 Subject: [PATCH] Fix case of no authority policy existing --- authority/admin/db/nosql/policy.go | 59 +- authority/admin/db/nosql/policy_test.go | 737 ++++++++++++++++++++++++ authority/policy.go | 6 +- 3 files changed, 774 insertions(+), 28 deletions(-) create mode 100644 authority/admin/db/nosql/policy_test.go diff --git a/authority/admin/db/nosql/policy.go b/authority/admin/db/nosql/policy.go index d26e44a0..b309f50c 100644 --- a/authority/admin/db/nosql/policy.go +++ b/authority/admin/db/nosql/policy.go @@ -3,12 +3,12 @@ package nosql import ( "context" "encoding/json" + "fmt" - "github.com/pkg/errors" + "go.step.sm/linkedca" "github.com/smallstep/certificates/authority/admin" "github.com/smallstep/nosql" - "go.step.sm/linkedca" ) type dbAuthorityPolicy struct { @@ -18,32 +18,29 @@ type dbAuthorityPolicy struct { } func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy { + if dbap == nil { + return nil + } return dbap.Policy } -func (dbap *dbAuthorityPolicy) clone() *dbAuthorityPolicy { - u := *dbap - return &u -} - func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) { data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID)) if nosql.IsErrNotFound(err) { - return nil, admin.NewError(admin.ErrorNotFoundType, "policy %s not found", authorityID) + return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found") } else if err != nil { - return nil, errors.Wrapf(err, "error loading admin %s", authorityID) + return nil, fmt.Errorf("error loading authority policy: %w", err) } return data, nil } -func (db *DB) unmarshalDBAuthorityPolicy(data []byte, authorityID string) (*dbAuthorityPolicy, error) { +func (db *DB) unmarshalDBAuthorityPolicy(data []byte) (*dbAuthorityPolicy, error) { + if len(data) == 0 { + return nil, nil + } var dba = new(dbAuthorityPolicy) if err := json.Unmarshal(data, dba); err != nil { - return nil, errors.Wrapf(err, "error unmarshaling admin %s into dbAdmin", authorityID) - } - if dba.AuthorityID != db.authorityID { - return nil, admin.NewError(admin.ErrorAuthorityMismatchType, - "admin %s is not owned by authority %s", dba.ID, db.authorityID) + return nil, fmt.Errorf("error unmarshaling policy bytes into dbAuthorityPolicy: %w", err) } return dba, nil } @@ -53,10 +50,17 @@ func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*db if err != nil { return nil, err } - dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID) + dbap, err := db.unmarshalDBAuthorityPolicy(data) if err != nil { return nil, err } + if dbap == nil { + return nil, nil + } + if dbap.AuthorityID != authorityID { + return nil, admin.NewError(admin.ErrorAuthorityMismatchType, + "authority policy is not owned by authority %s", authorityID) + } return dbap, nil } @@ -68,12 +72,11 @@ func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy Policy: policy, } - old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) - if err != nil { - return err + if err := db.save(ctx, dbap.ID, dbap, nil, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error creating authority policy") } - return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable) + return nil } func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) { @@ -97,16 +100,22 @@ func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy Policy: policy, } - return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable) + if err := db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error updating authority policy") + } + + return nil } func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error { - dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID) + old, err := db.getDBAuthorityPolicy(ctx, db.authorityID) if err != nil { return err } - old := dbap.clone() - dbap.Policy = nil - return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable) + if err := db.save(ctx, old.ID, nil, old, "authority_policy", authorityPoliciesTable); err != nil { + return admin.WrapErrorISE(err, "error deleting authority policy") + } + + return nil } diff --git a/authority/admin/db/nosql/policy_test.go b/authority/admin/db/nosql/policy_test.go new file mode 100644 index 00000000..09bcd070 --- /dev/null +++ b/authority/admin/db/nosql/policy_test.go @@ -0,0 +1,737 @@ +package nosql + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "go.step.sm/linkedca" + + "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/admin" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + nosqldb "github.com/smallstep/nosql/database" +) + +func TestDB_getDBAuthorityPolicyBytes(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "ok": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte("foo"), nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db} + if b, err := d.getDBAuthorityPolicyBytes(tc.ctx, tc.authorityID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, string(b), "foo") + } + }) + } +} + +func TestDB_getDBAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + dbap *dbAuthorityPolicy + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/unmarshal-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte("foo"), nil + }, + }, + err: errors.New("error unmarshaling policy bytes into dbAuthorityPolicy"), + } + }, + "fail/authorityID-error": func(t *testing.T) test { + dbp := &dbAuthorityPolicy{ + ID: "ID", + AuthorityID: "diffAuthID", + Policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }, + } + b, err := json.Marshal(dbp) + assert.FatalError(t, err) + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return b, nil + }, + }, + adminErr: admin.NewError(admin.ErrorAuthorityMismatchType, + "authority policy is not owned by authority authID"), + } + }, + "ok/empty-bytes": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return []byte{}, nil + }, + }, + } + }, + "ok": func(t *testing.T) test { + dbap := &dbAuthorityPolicy{ + ID: "ID", + AuthorityID: authID, + Policy: &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + }, + } + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return b, nil + }, + }, + dbap: dbap, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} + if dbp, err := d.getDBAuthorityPolicy(tc.ctx, tc.authorityID); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) && tc.dbap == nil { + assert.Nil(t, dbp) + } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { + assert.Equals(t, dbp.ID, "ID") + assert.Equals(t, dbp.AuthorityID, tc.dbap.AuthorityID) + assert.Equals(t, dbp.Policy, tc.dbap.Policy) + } + }) + } +} + +func TestDB_CreateAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/save-error": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, policy) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error creating authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, old, nil) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, policy) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.CreateAuthorityPolicy(tc.ctx, tc.policy); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + } + }) + } +} + +func TestDB_GetAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "ok": func(t *testing.T) test { + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: policy, + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + got, err := d.GetAuthorityPolicy(tc.ctx) + if err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + + assert.NotNil(t, got) + assert.Equals(t, tc.policy, got) + }) + } +} + +func TestDB_UpdateAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + policy *linkedca.Policy + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: oldPolicy, + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, policy) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error updating authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + policy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.local"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + policy: policy, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: oldPolicy, + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + var _dbap = new(dbAuthorityPolicy) + assert.FatalError(t, json.Unmarshal(nu, _dbap)) + + assert.Equals(t, _dbap.ID, authID) + assert.Equals(t, _dbap.AuthorityID, authID) + assert.Equals(t, _dbap.Policy, policy) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.UpdateAuthorityPolicy(tc.ctx, tc.policy); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + }) + } +} + +func TestDB_DeleteAuthorityPolicy(t *testing.T) { + authID := "authID" + type test struct { + ctx context.Context + authorityID string + db nosql.DB + err error + adminErr *admin.Error + } + var tests = map[string]func(t *testing.T) test{ + "fail/not-found": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + return nil, nosqldb.ErrNotFound + }, + }, + adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"), + } + }, + "fail/db.Get-error": func(t *testing.T) test { + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + return nil, errors.New("force") + }, + }, + err: errors.New("error loading authority policy: force"), + } + }, + "fail/save-error": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: oldPolicy, + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + assert.Equals(t, nil, nu) + + return nil, false, errors.New("force") + }, + }, + adminErr: admin.NewErrorISE("error deleting authority policy: error saving authority authority_policy: force"), + } + }, + "ok": func(t *testing.T) test { + oldPolicy := &linkedca.Policy{ + X509: &linkedca.X509Policy{ + Allow: &linkedca.X509Names{ + Dns: []string{"*.localhost"}, + }, + }, + } + return test{ + ctx: context.Background(), + authorityID: authID, + db: &db.MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + + dbap := &dbAuthorityPolicy{ + ID: authID, + AuthorityID: authID, + Policy: oldPolicy, + } + + b, err := json.Marshal(dbap) + assert.FatalError(t, err) + + return b, nil + }, + MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { + assert.Equals(t, bucket, authorityPoliciesTable) + assert.Equals(t, string(key), authID) + assert.Equals(t, nil, nu) + + return nil, true, nil + }, + }, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + d := DB{db: tc.db, authorityID: tc.authorityID} + if err := d.DeleteAuthorityPolicy(tc.ctx); err != nil { + switch k := err.(type) { + case *admin.Error: + if assert.NotNil(t, tc.adminErr) { + assert.Equals(t, k.Type, tc.adminErr.Type) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + assert.Equals(t, k.Status, tc.adminErr.Status) + assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error()) + assert.Equals(t, k.Detail, tc.adminErr.Detail) + } + default: + if assert.NotNil(t, tc.err) { + assert.HasPrefix(t, err.Error(), tc.err.Error()) + } + } + return + } + }) + } +} diff --git a/authority/policy.go b/authority/policy.go index 1793fb9e..b7d5e4ec 100644 --- a/authority/policy.go +++ b/authority/policy.go @@ -69,7 +69,7 @@ func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Adm } } - return p, nil // TODO: return the newly stored policy + return p, nil } func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) { @@ -94,7 +94,7 @@ func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Adm } } - return p, nil // TODO: return the updated stored policy + return p, nil } func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { @@ -111,7 +111,7 @@ func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error { if err := a.reloadPolicyEngines(ctx); err != nil { return &PolicyError{ Typ: ReloadFailure, - Err: fmt.Errorf("error reloading policy engines when deleting authority policy %w", err), + Err: fmt.Errorf("error reloading policy engines when deleting authority policy: %w", err), } }