Add GetCertificateData and refactor x509_certs_data.
This commit is contained in:
parent
41c6ded85e
commit
7d6116c3d0
2 changed files with 110 additions and 23 deletions
63
db/db.go
63
db/db.go
|
@ -15,15 +15,15 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
certsTable = []byte("x509_certs")
|
certsTable = []byte("x509_certs")
|
||||||
certsToProvisionerTable = []byte("x509_certs_provisioner")
|
certsDataTable = []byte("x509_certs_data")
|
||||||
revokedCertsTable = []byte("revoked_x509_certs")
|
revokedCertsTable = []byte("revoked_x509_certs")
|
||||||
revokedSSHCertsTable = []byte("revoked_ssh_certs")
|
revokedSSHCertsTable = []byte("revoked_ssh_certs")
|
||||||
usedOTTTable = []byte("used_ott")
|
usedOTTTable = []byte("used_ott")
|
||||||
sshCertsTable = []byte("ssh_certs")
|
sshCertsTable = []byte("ssh_certs")
|
||||||
sshHostsTable = []byte("ssh_hosts")
|
sshHostsTable = []byte("ssh_hosts")
|
||||||
sshUsersTable = []byte("ssh_users")
|
sshUsersTable = []byte("ssh_users")
|
||||||
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
||||||
|
@ -84,7 +84,7 @@ func New(c *Config) (AuthDB, error) {
|
||||||
tables := [][]byte{
|
tables := [][]byte{
|
||||||
revokedCertsTable, certsTable, usedOTTTable,
|
revokedCertsTable, certsTable, usedOTTTable,
|
||||||
sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable,
|
sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable,
|
||||||
revokedSSHCertsTable, certsToProvisionerTable,
|
revokedSSHCertsTable, certsDataTable,
|
||||||
}
|
}
|
||||||
for _, b := range tables {
|
for _, b := range tables {
|
||||||
if err := db.CreateTable(b); err != nil {
|
if err := db.CreateTable(b); err != nil {
|
||||||
|
@ -204,6 +204,19 @@ func (db *DB) GetCertificate(serialNumber string) (*x509.Certificate, error) {
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCertificateData returns the data stored for a provisioner
|
||||||
|
func (db *DB) GetCertificateData(serialNumber string) (*CertificateData, error) {
|
||||||
|
b, err := db.Get(certsDataTable, []byte(serialNumber))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "database Get error")
|
||||||
|
}
|
||||||
|
var data CertificateData
|
||||||
|
if err := json.Unmarshal(b, &data); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error unmarshaling json")
|
||||||
|
}
|
||||||
|
return &data, nil
|
||||||
|
}
|
||||||
|
|
||||||
// StoreCertificate stores a certificate PEM.
|
// StoreCertificate stores a certificate PEM.
|
||||||
func (db *DB) StoreCertificate(crt *x509.Certificate) error {
|
func (db *DB) StoreCertificate(crt *x509.Certificate) error {
|
||||||
if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil {
|
if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil {
|
||||||
|
@ -212,7 +225,15 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type certsToProvionersData struct {
|
// CertificateData is the JSON representation of the data stored in
|
||||||
|
// x509_certs_data table.
|
||||||
|
type CertificateData struct {
|
||||||
|
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProvisionerData is the JSON representation of the provisioner stored in the
|
||||||
|
// x509_certs_data table.
|
||||||
|
type ProvisionerData struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
|
@ -220,24 +241,26 @@ type certsToProvionersData struct {
|
||||||
|
|
||||||
// StoreCertificateChain stores the leaf certificate and the provisioner that
|
// StoreCertificateChain stores the leaf certificate and the provisioner that
|
||||||
// authorized the certificate.
|
// authorized the certificate.
|
||||||
func (d *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
|
func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
|
||||||
leaf := chain[0]
|
leaf := chain[0]
|
||||||
if err := d.StoreCertificate(leaf); err != nil {
|
if err := db.StoreCertificate(leaf); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
data := &CertificateData{}
|
||||||
if p != nil {
|
if p != nil {
|
||||||
b, err := json.Marshal(certsToProvionersData{
|
data.Provisioner = &ProvisionerData{
|
||||||
ID: p.GetID(),
|
ID: p.GetID(),
|
||||||
Name: p.GetName(),
|
Name: p.GetName(),
|
||||||
Type: p.GetType().String(),
|
Type: p.GetType().String(),
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "error marshaling json")
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := d.Set(certsToProvisionerTable, []byte(leaf.SerialNumber.String()), b); err != nil {
|
b, err := json.Marshal(data)
|
||||||
return errors.Wrap(err, "database Set error")
|
if err != nil {
|
||||||
}
|
return errors.Wrap(err, "error marshaling json")
|
||||||
|
}
|
||||||
|
if err := db.Set(certsDataTable, []byte(leaf.SerialNumber.String()), b); err != nil {
|
||||||
|
return errors.Wrap(err, "database Set error")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
|
@ -192,9 +193,9 @@ func TestDB_StoreCertificateChain(t *testing.T) {
|
||||||
case "x509_certs":
|
case "x509_certs":
|
||||||
assert.Equals(t, key, []byte("1234"))
|
assert.Equals(t, key, []byte("1234"))
|
||||||
assert.Equals(t, value, []byte("the certificate"))
|
assert.Equals(t, value, []byte("the certificate"))
|
||||||
case "x509_certs_provisioner":
|
case "x509_certs_data":
|
||||||
assert.Equals(t, key, []byte("1234"))
|
assert.Equals(t, key, []byte("1234"))
|
||||||
assert.Equals(t, value, []byte(`{"id":"some-id","name":"admin","type":"JWK"}`))
|
assert.Equals(t, value, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`))
|
||||||
default:
|
default:
|
||||||
t.Errorf("unexpected bucket %s", bucket)
|
t.Errorf("unexpected bucket %s", bucket)
|
||||||
}
|
}
|
||||||
|
@ -207,6 +208,9 @@ func TestDB_StoreCertificateChain(t *testing.T) {
|
||||||
case "x509_certs":
|
case "x509_certs":
|
||||||
assert.Equals(t, key, []byte("1234"))
|
assert.Equals(t, key, []byte("1234"))
|
||||||
assert.Equals(t, value, []byte("the certificate"))
|
assert.Equals(t, value, []byte("the certificate"))
|
||||||
|
case "x509_certs_data":
|
||||||
|
assert.Equals(t, key, []byte("1234"))
|
||||||
|
assert.Equals(t, value, []byte(`{}`))
|
||||||
default:
|
default:
|
||||||
t.Errorf("unexpected bucket %s", bucket)
|
t.Errorf("unexpected bucket %s", bucket)
|
||||||
}
|
}
|
||||||
|
@ -226,7 +230,7 @@ func TestDB_StoreCertificateChain(t *testing.T) {
|
||||||
{"fail store provisioner", fields{&MockNoSQLDB{
|
{"fail store provisioner", fields{&MockNoSQLDB{
|
||||||
MSet: func(bucket, key, value []byte) error {
|
MSet: func(bucket, key, value []byte) error {
|
||||||
switch string(bucket) {
|
switch string(bucket) {
|
||||||
case "x509_certs_provisioner":
|
case "x509_certs_data":
|
||||||
return errors.New("test error")
|
return errors.New("test error")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
@ -246,3 +250,63 @@ func TestDB_StoreCertificateChain(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDB_GetCertificateData(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
DB nosql.DB
|
||||||
|
isUp bool
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
serialNumber string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
want *CertificateData
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{&MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
assert.Equals(t, bucket, []byte("x509_certs_data"))
|
||||||
|
assert.Equals(t, key, []byte("1234"))
|
||||||
|
return []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), nil
|
||||||
|
},
|
||||||
|
}, true}, args{"1234"}, &CertificateData{
|
||||||
|
Provisioner: &ProvisionerData{
|
||||||
|
ID: "some-id", Name: "admin", Type: "JWK",
|
||||||
|
},
|
||||||
|
}, false},
|
||||||
|
{"fail not found", fields{&MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, database.ErrNotFound
|
||||||
|
},
|
||||||
|
}, true}, args{"1234"}, nil, true},
|
||||||
|
{"fail db", fields{&MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return nil, errors.New("an error")
|
||||||
|
},
|
||||||
|
}, true}, args{"1234"}, nil, true},
|
||||||
|
{"fail unmarshal", fields{&MockNoSQLDB{
|
||||||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
|
return []byte(`{"bad-json"}`), nil
|
||||||
|
},
|
||||||
|
}, true}, args{"1234"}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
db := &DB{
|
||||||
|
DB: tt.fields.DB,
|
||||||
|
isUp: tt.fields.isUp,
|
||||||
|
}
|
||||||
|
got, err := db.GetCertificateData(tt.args.serialNumber)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("DB.GetCertificateData() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("DB.GetCertificateData() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue