certificates/db/db_test.go

295 lines
7.5 KiB
Go

package db
import (
"crypto/x509"
"errors"
"math/big"
"reflect"
"testing"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
// TestIsRevoked exercises DB.IsRevoked against a mocked NoSQL backend.
//
// Cases cover a nil *DB (expected: not revoked, no error), a missing key
// (database.ErrNotFound expected to mean "not revoked"), a backend failure
// (expected to be wrapped with context), and a present key (revoked).
func TestIsRevoked(t *testing.T) {
	tests := map[string]struct {
		key       string
		db        *DB
		isRevoked bool
		err       error
	}{
		"false/nil db": {
			key: "sn",
		},
		"false/ErrNotFound": {
			key: "sn",
			db:  &DB{&MockNoSQLDB{Err: database.ErrNotFound, Ret1: nil}, true},
		},
		"error/checking bucket": {
			key: "sn",
			db:  &DB{&MockNoSQLDB{Err: errors.New("force"), Ret1: nil}, true},
			err: errors.New("error checking revocation bucket: force"),
		},
		"true": {
			key:       "sn",
			db:        &DB{&MockNoSQLDB{Ret1: []byte("value")}, true},
			isRevoked: true,
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			isRevoked, err := tc.db.IsRevoked(tc.key)
			if err != nil {
				// The actual error must start with the expected message.
				// Checking it the other way round would accept a truncated
				// (or even empty) actual error.
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
				return
			}
			assert.Nil(t, tc.err)
			assert.Equals(t, tc.isRevoked, isRevoked)
		})
	}
}
// TestRevoke exercises DB.Revoke against a mocked CmpAndSwap.
//
// A CmpAndSwap error is expected to be wrapped with context, a failed swap
// (serial already present) is expected to surface as ErrAlreadyExists, and a
// successful swap is expected to return nil.
func TestRevoke(t *testing.T) {
	tests := map[string]struct {
		rci *RevokedCertificateInfo
		db  *DB
		err error
	}{
		// Despite its name, this case forces an error from CmpAndSwap.
		"error/force isRevoked": {
			rci: &RevokedCertificateInfo{Serial: "sn"},
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
					return nil, false, errors.New("force")
				},
			}, true},
			err: errors.New("error AuthDB CmpAndSwap: force"),
		},
		"error/was already revoked": {
			rci: &RevokedCertificateInfo{Serial: "sn"},
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
					return []byte("foo"), false, nil
				},
			}, true},
			err: ErrAlreadyExists,
		},
		"ok": {
			rci: &RevokedCertificateInfo{Serial: "sn"},
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
					return []byte("foo"), true, nil
				},
			}, true},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			err := tc.db.Revoke(tc.rci)
			if err == nil {
				assert.Nil(t, tc.err)
				return
			}
			// The actual error must begin with the expected message; the
			// reverse check would accept a truncated or empty error.
			if assert.NotNil(t, tc.err) {
				assert.HasPrefix(t, err.Error(), tc.err.Error())
			}
		})
	}
}
// TestUseToken exercises DB.UseToken against a mocked CmpAndSwap.
//
// A CmpAndSwap error is expected to be returned (wrapped) with ok == false.
// A failed swap (token already stored) is expected to report ok == false
// without error, and a successful swap ok == true.
func TestUseToken(t *testing.T) {
	type result struct {
		err error
		ok  bool
	}
	tests := map[string]struct {
		id, tok string
		db      *DB
		want    result
	}{
		"fail/force-CmpAndSwap-error": {
			id:  "id",
			tok: "token",
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
					return nil, false, errors.New("force")
				},
			}, true},
			want: result{
				ok:  false,
				err: errors.New("error storing used token used_ott/id"),
			},
		},
		"fail/CmpAndSwap-already-exists": {
			id:  "id",
			tok: "token",
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
					return []byte("foo"), false, nil
				},
			}, true},
			want: result{
				ok: false,
			},
		},
		"ok/cmpAndSwap-success": {
			id:  "id",
			tok: "token",
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
					return []byte("bar"), true, nil
				},
			}, true},
			want: result{
				ok: true,
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			ok, err := tc.db.UseToken(tc.id, tc.tok)
			if err != nil {
				if assert.NotNil(t, tc.want.err) {
					assert.HasPrefix(t, err.Error(), tc.want.err.Error())
				}
				assert.False(t, ok)
				return
			}
			// No error was returned, so the case must not have expected one;
			// otherwise a missing error would go unnoticed.
			assert.Nil(t, tc.want.err)
			assert.Equals(t, tc.want.ok, ok)
		})
	}
}
// TestDB_StoreCertificateChain checks that StoreCertificateChain writes the
// certificate (keyed by serial number) and its provisioner metadata as two
// operations of a single Update transaction, and that a failing transaction
// is reported as an error.
func TestDB_StoreCertificateChain(t *testing.T) {
	p := &provisioner.JWK{
		ID:   "some-id",
		Name: "admin",
		Type: "JWK",
	}
	chain := []*x509.Certificate{
		{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
	}

	// expectTx returns an MUpdate hook that validates the transaction and
	// expects wantData as the certificate-data value. It must be given the
	// *subtest's* t: the hook runs on the subtest goroutine, and calling
	// Fatal on the parent test from there makes the testing package abort
	// with "subtest may have called FailNow on a parent test".
	expectTx := func(t *testing.T, wantData []byte) func(tx *database.Tx) error {
		return func(tx *database.Tx) error {
			if len(tx.Operations) != 2 {
				t.Fatalf("unexpected number of operations: got %d, want 2", len(tx.Operations))
			}
			assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
			assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
			assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
			assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
			assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
			assert.Equals(t, wantData, tx.Operations[1].Value)
			return nil
		}
	}

	type fields struct {
		// newDB builds the backing store for one subtest; it receives that
		// subtest's t so mock assertions are reported against it.
		newDB func(t *testing.T) nosql.DB
		isUp  bool
	}
	type args struct {
		p     provisioner.Interface
		chain []*x509.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		{"ok", fields{func(t *testing.T) nosql.DB {
			return &MockNoSQLDB{
				MUpdate: expectTx(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`)),
			}
		}, true}, args{p, chain}, false},
		{"ok no provisioner", fields{func(t *testing.T) nosql.DB {
			return &MockNoSQLDB{
				MUpdate: expectTx(t, []byte(`{}`)),
			}
		}, true}, args{nil, chain}, false},
		{"fail store certificate", fields{func(*testing.T) nosql.DB {
			return &MockNoSQLDB{
				MUpdate: func(tx *database.Tx) error {
					return errors.New("test error")
				},
			}
		}, true}, args{p, chain}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			d := &DB{
				DB:   tt.fields.newDB(t),
				isUp: tt.fields.isUp,
			}
			if err := d.StoreCertificateChain(tt.args.p, tt.args.chain...); (err != nil) != tt.wantErr {
				t.Errorf("DB.StoreCertificateChain() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}
// TestDB_GetCertificateData checks that GetCertificateData reads the
// "x509_certs_data" bucket by serial number and decodes the stored JSON. Not
// found, backend, and malformed-JSON failures are all expected to be errors.
func TestDB_GetCertificateData(t *testing.T) {
	// stubGet builds a store whose Get always returns (b, err) without
	// inspecting its arguments.
	stubGet := func(b []byte, err error) func(*testing.T) nosql.DB {
		return func(*testing.T) nosql.DB {
			return &MockNoSQLDB{
				MGet: func(bucket, key []byte) ([]byte, error) {
					return b, err
				},
			}
		}
	}

	type fields struct {
		// newDB builds the backing store for one subtest; it receives that
		// subtest's t so mock assertions are reported against it rather than
		// against the parent test.
		newDB func(t *testing.T) nosql.DB
		isUp  bool
	}
	type args struct {
		serialNumber string
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *CertificateData
		wantErr bool
	}{
		{"ok", fields{func(t *testing.T) nosql.DB {
			return &MockNoSQLDB{
				MGet: func(bucket, key []byte) ([]byte, error) {
					assert.Equals(t, []byte("x509_certs_data"), bucket)
					assert.Equals(t, []byte("1234"), key)
					return []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), nil
				},
			}
		}, true}, args{"1234"}, &CertificateData{
			Provisioner: &ProvisionerData{
				ID: "some-id", Name: "admin", Type: "JWK",
			},
		}, false},
		{"fail not found", fields{stubGet(nil, database.ErrNotFound), true}, args{"1234"}, nil, true},
		{"fail db", fields{stubGet(nil, errors.New("an error")), true}, args{"1234"}, nil, true},
		{"fail unmarshal", fields{stubGet([]byte(`{"bad-json"}`), nil), true}, args{"1234"}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			d := &DB{
				DB:   tt.fields.newDB(t),
				isUp: tt.fields.isUp,
			}
			got, err := d.GetCertificateData(tt.args.serialNumber)
			if (err != nil) != tt.wantErr {
				t.Errorf("DB.GetCertificateData() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("DB.GetCertificateData() = %v, want %v", got, tt.want)
			}
		})
	}
}