certificates/authority/admin/db/nosql/policy_test.go
2022-04-18 21:47:13 +02:00

739 lines
20 KiB
Go

package nosql
import (
"context"
"encoding/json"
"errors"
"testing"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database"
)
func TestDB_getDBAuthorityPolicyBytes(t *testing.T) {
authID := "authID"
type test struct {
ctx context.Context
authorityID string
db nosql.DB
err error
adminErr *admin.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, nosqldb.ErrNotFound
},
},
adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, errors.New("force")
},
},
err: errors.New("error loading authority policy: force"),
}
},
"ok": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return []byte("foo"), nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db}
if b, err := d.getDBAuthorityPolicyBytes(tc.ctx, tc.authorityID); err != nil {
switch k := err.(type) {
case *admin.Error:
if assert.NotNil(t, tc.adminErr) {
assert.Equals(t, k.Type, tc.adminErr.Type)
assert.Equals(t, k.Detail, tc.adminErr.Detail)
assert.Equals(t, k.Status, tc.adminErr.Status)
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
assert.Equals(t, k.Detail, tc.adminErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
assert.Equals(t, string(b), "foo")
}
})
}
}
func TestDB_getDBAuthorityPolicy(t *testing.T) {
authID := "authID"
type test struct {
ctx context.Context
authorityID string
db nosql.DB
err error
adminErr *admin.Error
dbap *dbAuthorityPolicy
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, nosqldb.ErrNotFound
},
},
adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return []byte("foo"), nil
},
},
err: errors.New("error unmarshaling policy bytes into dbAuthorityPolicy"),
}
},
"fail/authorityID-error": func(t *testing.T) test {
dbp := &dbAuthorityPolicy{
ID: "ID",
AuthorityID: "diffAuthID",
Policy: &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
},
},
}
b, err := json.Marshal(dbp)
assert.FatalError(t, err)
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return b, nil
},
},
adminErr: admin.NewError(admin.ErrorAuthorityMismatchType,
"authority policy is not owned by authority authID"),
}
},
"ok/empty-bytes": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return []byte{}, nil
},
},
}
},
"ok": func(t *testing.T) test {
dbap := &dbAuthorityPolicy{
ID: "ID",
AuthorityID: authID,
Policy: &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
},
},
}
b, err := json.Marshal(dbap)
assert.FatalError(t, err)
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return b, nil
},
},
dbap: dbap,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
dbp, err := d.getDBAuthorityPolicy(tc.ctx, tc.authorityID)
switch {
case err != nil:
switch k := err.(type) {
case *admin.Error:
if assert.NotNil(t, tc.adminErr) {
assert.Equals(t, k.Type, tc.adminErr.Type)
assert.Equals(t, k.Detail, tc.adminErr.Detail)
assert.Equals(t, k.Status, tc.adminErr.Status)
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
assert.Equals(t, k.Detail, tc.adminErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
case assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) && tc.dbap == nil:
assert.Nil(t, dbp)
case assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr):
assert.Equals(t, dbp.ID, "ID")
assert.Equals(t, dbp.AuthorityID, tc.dbap.AuthorityID)
assert.Equals(t, dbp.Policy, tc.dbap.Policy)
}
})
}
}
func TestDB_CreateAuthorityPolicy(t *testing.T) {
authID := "authID"
type test struct {
ctx context.Context
authorityID string
policy *linkedca.Policy
db nosql.DB
err error
adminErr *admin.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/save-error": func(t *testing.T) test {
policy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
},
}
return test{
ctx: context.Background(),
authorityID: authID,
policy: policy,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
var _dbap = new(dbAuthorityPolicy)
assert.FatalError(t, json.Unmarshal(nu, _dbap))
assert.Equals(t, _dbap.ID, authID)
assert.Equals(t, _dbap.AuthorityID, authID)
assert.Equals(t, _dbap.Policy, policy)
return nil, false, errors.New("force")
},
},
adminErr: admin.NewErrorISE("error creating authority policy: error saving authority authority_policy: force"),
}
},
"ok": func(t *testing.T) test {
policy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
},
}
return test{
ctx: context.Background(),
authorityID: authID,
policy: policy,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, old, nil)
var _dbap = new(dbAuthorityPolicy)
assert.FatalError(t, json.Unmarshal(nu, _dbap))
assert.Equals(t, _dbap.ID, authID)
assert.Equals(t, _dbap.AuthorityID, authID)
assert.Equals(t, _dbap.Policy, policy)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db, authorityID: tc.authorityID}
if err := d.CreateAuthorityPolicy(tc.ctx, tc.policy); err != nil {
switch k := err.(type) {
case *admin.Error:
if assert.NotNil(t, tc.adminErr) {
assert.Equals(t, k.Type, tc.adminErr.Type)
assert.Equals(t, k.Detail, tc.adminErr.Detail)
assert.Equals(t, k.Status, tc.adminErr.Status)
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
assert.Equals(t, k.Detail, tc.adminErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
}
})
}
}
func TestDB_GetAuthorityPolicy(t *testing.T) {
authID := "authID"
type test struct {
ctx context.Context
authorityID string
policy *linkedca.Policy
db nosql.DB
err error
adminErr *admin.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, nosqldb.ErrNotFound
},
},
adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, errors.New("force")
},
},
err: errors.New("error loading authority policy: force"),
}
},
"ok": func(t *testing.T) test {
policy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
},
}
return test{
ctx: context.Background(),
authorityID: authID,
policy: policy,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
dbap := &dbAuthorityPolicy{
ID: authID,
AuthorityID: authID,
Policy: policy,
}
b, err := json.Marshal(dbap)
assert.FatalError(t, err)
return b, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db, authorityID: tc.authorityID}
got, err := d.GetAuthorityPolicy(tc.ctx)
if err != nil {
switch k := err.(type) {
case *admin.Error:
if assert.NotNil(t, tc.adminErr) {
assert.Equals(t, k.Type, tc.adminErr.Type)
assert.Equals(t, k.Detail, tc.adminErr.Detail)
assert.Equals(t, k.Status, tc.adminErr.Status)
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
assert.Equals(t, k.Detail, tc.adminErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
return
}
assert.NotNil(t, got)
assert.Equals(t, tc.policy, got)
})
}
}
func TestDB_UpdateAuthorityPolicy(t *testing.T) {
authID := "authID"
type test struct {
ctx context.Context
authorityID string
policy *linkedca.Policy
db nosql.DB
err error
adminErr *admin.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, nosqldb.ErrNotFound
},
},
adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, errors.New("force")
},
},
err: errors.New("error loading authority policy: force"),
}
},
"fail/save-error": func(t *testing.T) test {
oldPolicy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.localhost"},
},
},
}
policy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
},
}
return test{
ctx: context.Background(),
authorityID: authID,
policy: policy,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
dbap := &dbAuthorityPolicy{
ID: authID,
AuthorityID: authID,
Policy: oldPolicy,
}
b, err := json.Marshal(dbap)
assert.FatalError(t, err)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
var _dbap = new(dbAuthorityPolicy)
assert.FatalError(t, json.Unmarshal(nu, _dbap))
assert.Equals(t, _dbap.ID, authID)
assert.Equals(t, _dbap.AuthorityID, authID)
assert.Equals(t, _dbap.Policy, policy)
return nil, false, errors.New("force")
},
},
adminErr: admin.NewErrorISE("error updating authority policy: error saving authority authority_policy: force"),
}
},
"ok": func(t *testing.T) test {
oldPolicy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.localhost"},
},
},
}
policy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
},
}
return test{
ctx: context.Background(),
authorityID: authID,
policy: policy,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
dbap := &dbAuthorityPolicy{
ID: authID,
AuthorityID: authID,
Policy: oldPolicy,
}
b, err := json.Marshal(dbap)
assert.FatalError(t, err)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
var _dbap = new(dbAuthorityPolicy)
assert.FatalError(t, json.Unmarshal(nu, _dbap))
assert.Equals(t, _dbap.ID, authID)
assert.Equals(t, _dbap.AuthorityID, authID)
assert.Equals(t, _dbap.Policy, policy)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db, authorityID: tc.authorityID}
if err := d.UpdateAuthorityPolicy(tc.ctx, tc.policy); err != nil {
switch k := err.(type) {
case *admin.Error:
if assert.NotNil(t, tc.adminErr) {
assert.Equals(t, k.Type, tc.adminErr.Type)
assert.Equals(t, k.Detail, tc.adminErr.Detail)
assert.Equals(t, k.Status, tc.adminErr.Status)
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
assert.Equals(t, k.Detail, tc.adminErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
return
}
})
}
}
func TestDB_DeleteAuthorityPolicy(t *testing.T) {
authID := "authID"
type test struct {
ctx context.Context
authorityID string
db nosql.DB
err error
adminErr *admin.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, nosqldb.ErrNotFound
},
},
adminErr: admin.NewError(admin.ErrorNotFoundType, "authority policy not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
return nil, errors.New("force")
},
},
err: errors.New("error loading authority policy: force"),
}
},
"fail/save-error": func(t *testing.T) test {
oldPolicy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.localhost"},
},
},
}
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
dbap := &dbAuthorityPolicy{
ID: authID,
AuthorityID: authID,
Policy: oldPolicy,
}
b, err := json.Marshal(dbap)
assert.FatalError(t, err)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
assert.Equals(t, nil, nu)
return nil, false, errors.New("force")
},
},
adminErr: admin.NewErrorISE("error deleting authority policy: error saving authority authority_policy: force"),
}
},
"ok": func(t *testing.T) test {
oldPolicy := &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.localhost"},
},
},
}
return test{
ctx: context.Background(),
authorityID: authID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
dbap := &dbAuthorityPolicy{
ID: authID,
AuthorityID: authID,
Policy: oldPolicy,
}
b, err := json.Marshal(dbap)
assert.FatalError(t, err)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, authorityPoliciesTable)
assert.Equals(t, string(key), authID)
assert.Equals(t, nil, nu)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
d := DB{db: tc.db, authorityID: tc.authorityID}
if err := d.DeleteAuthorityPolicy(tc.ctx); err != nil {
switch k := err.(type) {
case *admin.Error:
if assert.NotNil(t, tc.adminErr) {
assert.Equals(t, k.Type, tc.adminErr.Type)
assert.Equals(t, k.Detail, tc.adminErr.Detail)
assert.Equals(t, k.Status, tc.adminErr.Status)
assert.Equals(t, k.Err.Error(), tc.adminErr.Err.Error())
assert.Equals(t, k.Detail, tc.adminErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
return
}
})
}
}