Add GetCertificateData and refactor x509_certs_data.

This commit is contained in:
Mariano Cano 2022-04-05 19:24:53 -07:00
parent 41c6ded85e
commit 7d6116c3d0
2 changed files with 110 additions and 23 deletions

View file

@ -16,7 +16,7 @@ import (
var ( var (
certsTable = []byte("x509_certs") certsTable = []byte("x509_certs")
certsToProvisionerTable = []byte("x509_certs_provisioner") certsDataTable = []byte("x509_certs_data")
revokedCertsTable = []byte("revoked_x509_certs") revokedCertsTable = []byte("revoked_x509_certs")
revokedSSHCertsTable = []byte("revoked_ssh_certs") revokedSSHCertsTable = []byte("revoked_ssh_certs")
usedOTTTable = []byte("used_ott") usedOTTTable = []byte("used_ott")
@ -84,7 +84,7 @@ func New(c *Config) (AuthDB, error) {
tables := [][]byte{ tables := [][]byte{
revokedCertsTable, certsTable, usedOTTTable, revokedCertsTable, certsTable, usedOTTTable,
sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable, sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable,
revokedSSHCertsTable, certsToProvisionerTable, revokedSSHCertsTable, certsDataTable,
} }
for _, b := range tables { for _, b := range tables {
if err := db.CreateTable(b); err != nil { if err := db.CreateTable(b); err != nil {
@ -204,6 +204,19 @@ func (db *DB) GetCertificate(serialNumber string) (*x509.Certificate, error) {
return cert, nil return cert, nil
} }
// GetCertificateData returns the data stored for a provisioner
func (db *DB) GetCertificateData(serialNumber string) (*CertificateData, error) {
b, err := db.Get(certsDataTable, []byte(serialNumber))
if err != nil {
return nil, errors.Wrap(err, "database Get error")
}
var data CertificateData
if err := json.Unmarshal(b, &data); err != nil {
return nil, errors.Wrap(err, "error unmarshaling json")
}
return &data, nil
}
// StoreCertificate stores a certificate PEM. // StoreCertificate stores a certificate PEM.
func (db *DB) StoreCertificate(crt *x509.Certificate) error { func (db *DB) StoreCertificate(crt *x509.Certificate) error {
if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil { if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil {
@ -212,7 +225,15 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
return nil return nil
} }
type certsToProvionersData struct { // CertificateData is the JSON representation of the data stored in
// x509_certs_data table.
type CertificateData struct {
Provisioner *ProvisionerData `json:"provisioner,omitempty"`
}
// ProvisionerData is the JSON representation of the provisioner stored in the
// x509_certs_data table.
type ProvisionerData struct {
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` Type string `json:"type"`
@ -220,25 +241,27 @@ type certsToProvionersData struct {
// StoreCertificateChain stores the leaf certificate and the provisioner that // StoreCertificateChain stores the leaf certificate and the provisioner that
// authorized the certificate. // authorized the certificate.
func (d *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
leaf := chain[0] leaf := chain[0]
if err := d.StoreCertificate(leaf); err != nil { if err := db.StoreCertificate(leaf); err != nil {
return err return err
} }
data := &CertificateData{}
if p != nil { if p != nil {
b, err := json.Marshal(certsToProvionersData{ data.Provisioner = &ProvisionerData{
ID: p.GetID(), ID: p.GetID(),
Name: p.GetName(), Name: p.GetName(),
Type: p.GetType().String(), Type: p.GetType().String(),
}) }
}
b, err := json.Marshal(data)
if err != nil { if err != nil {
return errors.Wrap(err, "error marshaling json") return errors.Wrap(err, "error marshaling json")
} }
if err := db.Set(certsDataTable, []byte(leaf.SerialNumber.String()), b); err != nil {
if err := d.Set(certsToProvisionerTable, []byte(leaf.SerialNumber.String()), b); err != nil {
return errors.Wrap(err, "database Set error") return errors.Wrap(err, "database Set error")
} }
}
return nil return nil
} }

View file

@ -4,6 +4,7 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"math/big" "math/big"
"reflect"
"testing" "testing"
"github.com/smallstep/assert" "github.com/smallstep/assert"
@ -192,9 +193,9 @@ func TestDB_StoreCertificateChain(t *testing.T) {
case "x509_certs": case "x509_certs":
assert.Equals(t, key, []byte("1234")) assert.Equals(t, key, []byte("1234"))
assert.Equals(t, value, []byte("the certificate")) assert.Equals(t, value, []byte("the certificate"))
case "x509_certs_provisioner": case "x509_certs_data":
assert.Equals(t, key, []byte("1234")) assert.Equals(t, key, []byte("1234"))
assert.Equals(t, value, []byte(`{"id":"some-id","name":"admin","type":"JWK"}`)) assert.Equals(t, value, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`))
default: default:
t.Errorf("unexpected bucket %s", bucket) t.Errorf("unexpected bucket %s", bucket)
} }
@ -207,6 +208,9 @@ func TestDB_StoreCertificateChain(t *testing.T) {
case "x509_certs": case "x509_certs":
assert.Equals(t, key, []byte("1234")) assert.Equals(t, key, []byte("1234"))
assert.Equals(t, value, []byte("the certificate")) assert.Equals(t, value, []byte("the certificate"))
case "x509_certs_data":
assert.Equals(t, key, []byte("1234"))
assert.Equals(t, value, []byte(`{}`))
default: default:
t.Errorf("unexpected bucket %s", bucket) t.Errorf("unexpected bucket %s", bucket)
} }
@ -226,7 +230,7 @@ func TestDB_StoreCertificateChain(t *testing.T) {
{"fail store provisioner", fields{&MockNoSQLDB{ {"fail store provisioner", fields{&MockNoSQLDB{
MSet: func(bucket, key, value []byte) error { MSet: func(bucket, key, value []byte) error {
switch string(bucket) { switch string(bucket) {
case "x509_certs_provisioner": case "x509_certs_data":
return errors.New("test error") return errors.New("test error")
default: default:
return nil return nil
@ -246,3 +250,63 @@ func TestDB_StoreCertificateChain(t *testing.T) {
}) })
} }
} }
func TestDB_GetCertificateData(t *testing.T) {
type fields struct {
DB nosql.DB
isUp bool
}
type args struct {
serialNumber string
}
tests := []struct {
name string
fields fields
args args
want *CertificateData
wantErr bool
}{
{"ok", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, []byte("x509_certs_data"))
assert.Equals(t, key, []byte("1234"))
return []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), nil
},
}, true}, args{"1234"}, &CertificateData{
Provisioner: &ProvisionerData{
ID: "some-id", Name: "admin", Type: "JWK",
},
}, false},
{"fail not found", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
}, true}, args{"1234"}, nil, true},
{"fail db", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("an error")
},
}, true}, args{"1234"}, nil, true},
{"fail unmarshal", fields{&MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return []byte(`{"bad-json"}`), nil
},
}, true}, args{"1234"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := &DB{
DB: tt.fields.DB,
isUp: tt.fields.isUp,
}
got, err := db.GetCertificateData(tt.args.serialNumber)
if (err != nil) != tt.wantErr {
t.Errorf("DB.GetCertificateData() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("DB.GetCertificateData() = %v, want %v", got, tt.want)
}
})
}
}