From 7d6116c3d052bef1f2e723e67e177280b621f3c4 Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 5 Apr 2022 19:24:53 -0700 Subject: [PATCH] Add GetCertificateData and refactor x509_certs_data. --- db/db.go | 63 +++++++++++++++++++++++++++++++--------------- db/db_test.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 110 insertions(+), 23 deletions(-) diff --git a/db/db.go b/db/db.go index 3427d2bb..a3ebb19f 100644 --- a/db/db.go +++ b/db/db.go @@ -15,15 +15,15 @@ import ( ) var ( - certsTable = []byte("x509_certs") - certsToProvisionerTable = []byte("x509_certs_provisioner") - revokedCertsTable = []byte("revoked_x509_certs") - revokedSSHCertsTable = []byte("revoked_ssh_certs") - usedOTTTable = []byte("used_ott") - sshCertsTable = []byte("ssh_certs") - sshHostsTable = []byte("ssh_hosts") - sshUsersTable = []byte("ssh_users") - sshHostPrincipalsTable = []byte("ssh_host_principals") + certsTable = []byte("x509_certs") + certsDataTable = []byte("x509_certs_data") + revokedCertsTable = []byte("revoked_x509_certs") + revokedSSHCertsTable = []byte("revoked_ssh_certs") + usedOTTTable = []byte("used_ott") + sshCertsTable = []byte("ssh_certs") + sshHostsTable = []byte("ssh_hosts") + sshUsersTable = []byte("ssh_users") + sshHostPrincipalsTable = []byte("ssh_host_principals") ) // ErrAlreadyExists can be returned if the DB attempts to set a key that has @@ -84,7 +84,7 @@ func New(c *Config) (AuthDB, error) { tables := [][]byte{ revokedCertsTable, certsTable, usedOTTTable, sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable, - revokedSSHCertsTable, certsToProvisionerTable, + revokedSSHCertsTable, certsDataTable, } for _, b := range tables { if err := db.CreateTable(b); err != nil { @@ -204,6 +204,19 @@ func (db *DB) GetCertificate(serialNumber string) (*x509.Certificate, error) { return cert, nil } +// GetCertificateData returns the data stored for a provisioner +func (db *DB) GetCertificateData(serialNumber string) (*CertificateData, error) { + b, err := db.Get(certsDataTable, []byte(serialNumber)) + if err != nil { + return nil, errors.Wrap(err, "database Get error") + } + var data CertificateData + if err := json.Unmarshal(b, &data); err != nil { + return nil, errors.Wrap(err, "error unmarshaling json") + } + return &data, nil +} + // StoreCertificate stores a certificate PEM. func (db *DB) StoreCertificate(crt *x509.Certificate) error { if err := db.Set(certsTable, []byte(crt.SerialNumber.String()), crt.Raw); err != nil { @@ -212,7 +225,15 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error { return nil } -type certsToProvionersData struct { +// CertificateData is the JSON representation of the data stored in +// x509_certs_data table. +type CertificateData struct { + Provisioner *ProvisionerData `json:"provisioner,omitempty"` +} + +// ProvisionerData is the JSON representation of the provisioner stored in the +// x509_certs_data table. +type ProvisionerData struct { ID string `json:"id"` Name string `json:"name"` Type string `json:"type"` @@ -220,24 +241,26 @@ type certsToProvionersData struct { // StoreCertificateChain stores the leaf certificate and the provisioner that // authorized the certificate. -func (d *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { +func (db *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error { leaf := chain[0] - if err := d.StoreCertificate(leaf); err != nil { + if err := db.StoreCertificate(leaf); err != nil { return err } + data := &CertificateData{} if p != nil { - b, err := json.Marshal(certsToProvionersData{ + data.Provisioner = &ProvisionerData{ ID: p.GetID(), Name: p.GetName(), Type: p.GetType().String(), - }) - if err != nil { - return errors.Wrap(err, "error marshaling json") } + } - if err := d.Set(certsToProvisionerTable, []byte(leaf.SerialNumber.String()), b); err != nil { - return errors.Wrap(err, "database Set error") - } + b, err := json.Marshal(data) + if err != nil { + return errors.Wrap(err, "error marshaling json") + } + if err := db.Set(certsDataTable, []byte(leaf.SerialNumber.String()), b); err != nil { + return errors.Wrap(err, "database Set error") } return nil } diff --git a/db/db_test.go b/db/db_test.go index 5a7e2d38..d7c58c9c 100644 --- a/db/db_test.go +++ b/db/db_test.go @@ -4,6 +4,7 @@ import ( "crypto/x509" "errors" "math/big" + "reflect" "testing" "github.com/smallstep/assert" @@ -192,9 +193,9 @@ func TestDB_StoreCertificateChain(t *testing.T) { case "x509_certs": assert.Equals(t, key, []byte("1234")) assert.Equals(t, value, []byte("the certificate")) - case "x509_certs_provisioner": + case "x509_certs_data": assert.Equals(t, key, []byte("1234")) - assert.Equals(t, value, []byte(`{"id":"some-id","name":"admin","type":"JWK"}`)) + assert.Equals(t, value, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`)) default: t.Errorf("unexpected bucket %s", bucket) } @@ -207,6 +208,9 @@ func TestDB_StoreCertificateChain(t *testing.T) { case "x509_certs": assert.Equals(t, key, []byte("1234")) assert.Equals(t, value, []byte("the certificate")) + case "x509_certs_data": + assert.Equals(t, key, []byte("1234")) + assert.Equals(t, value, []byte(`{}`)) default: t.Errorf("unexpected bucket %s", bucket) } @@ -226,7 +230,7 @@ func TestDB_StoreCertificateChain(t *testing.T) { {"fail store provisioner", fields{&MockNoSQLDB{ MSet: func(bucket, key, value []byte) error { switch string(bucket) { - case "x509_certs_provisioner": + case "x509_certs_data": return errors.New("test error") default: return nil @@ -246,3 +250,63 @@ func TestDB_StoreCertificateChain(t *testing.T) { }) } } + +func TestDB_GetCertificateData(t *testing.T) { + type fields struct { + DB nosql.DB + isUp bool + } + type args struct { + serialNumber string + } + tests := []struct { + name string + fields fields + args args + want *CertificateData + wantErr bool + }{ + {"ok", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + assert.Equals(t, bucket, []byte("x509_certs_data")) + assert.Equals(t, key, []byte("1234")) + return []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), nil + }, + }, true}, args{"1234"}, &CertificateData{ + Provisioner: &ProvisionerData{ + ID: "some-id", Name: "admin", Type: "JWK", + }, + }, false}, + {"fail not found", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, database.ErrNotFound + }, + }, true}, args{"1234"}, nil, true}, + {"fail db", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return nil, errors.New("an error") + }, + }, true}, args{"1234"}, nil, true}, + {"fail unmarshal", fields{&MockNoSQLDB{ + MGet: func(bucket, key []byte) ([]byte, error) { + return []byte(`{"bad-json"}`), nil + }, + }, true}, args{"1234"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := &DB{ + DB: tt.fields.DB, + isUp: tt.fields.isUp, + } + got, err := db.GetCertificateData(tt.args.serialNumber) + if (err != nil) != tt.wantErr { + t.Errorf("DB.GetCertificateData() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DB.GetCertificateData() = %v, want %v", got, tt.want) + } + }) + } +}