Store in the db the provisioner that granted a cert.

This commit is contained in:
Mariano Cano 2022-04-05 18:00:01 -07:00
parent df8ffb35af
commit 41c6ded85e
2 changed files with 129 additions and 9 deletions

View file

@ -8,6 +8,7 @@ import (
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
"golang.org/x/crypto/ssh"
@ -15,6 +16,7 @@ import (
var (
certsTable = []byte("x509_certs")
certsToProvisionerTable = []byte("x509_certs_provisioner")
revokedCertsTable = []byte("revoked_x509_certs")
revokedSSHCertsTable = []byte("revoked_ssh_certs")
usedOTTTable = []byte("used_ott")
@ -82,7 +84,7 @@ func New(c *Config) (AuthDB, error) {
tables := [][]byte{
revokedCertsTable, certsTable, usedOTTTable,
sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable,
revokedSSHCertsTable,
revokedSSHCertsTable, certsToProvisionerTable,
}
for _, b := range tables {
if err := db.CreateTable(b); err != nil {
@ -210,6 +212,36 @@ func (db *DB) StoreCertificate(crt *x509.Certificate) error {
return nil
}
type certsToProvionersData struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
}
// StoreCertificateChain stores the leaf certificate and the provisioner that
// authorized the certificate.
func (d *DB) StoreCertificateChain(p provisioner.Interface, chain ...*x509.Certificate) error {
leaf := chain[0]
if err := d.StoreCertificate(leaf); err != nil {
return err
}
if p != nil {
b, err := json.Marshal(certsToProvionersData{
ID: p.GetID(),
Name: p.GetName(),
Type: p.GetType().String(),
})
if err != nil {
return errors.Wrap(err, "error marshaling json")
}
if err := d.Set(certsToProvisionerTable, []byte(leaf.SerialNumber.String()), b); err != nil {
return errors.Wrap(err, "database Set error")
}
}
return nil
}
// UseToken returns true if we were able to successfully store the token for
// for the first time, false otherwise.
func (db *DB) UseToken(id, tok string) (bool, error) {

View file

@ -1,10 +1,14 @@
package db
import (
"crypto/x509"
"errors"
"math/big"
"testing"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
)
@ -158,3 +162,87 @@ func TestUseToken(t *testing.T) {
})
}
}
func TestDB_StoreCertificateChain(t *testing.T) {
p := &provisioner.JWK{
ID: "some-id",
Name: "admin",
Type: "JWK",
}
chain := []*x509.Certificate{
{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
}
type fields struct {
DB nosql.DB
isUp bool
}
type args struct {
p provisioner.Interface
chain []*x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&MockNoSQLDB{
MSet: func(bucket, key, value []byte) error {
switch string(bucket) {
case "x509_certs":
assert.Equals(t, key, []byte("1234"))
assert.Equals(t, value, []byte("the certificate"))
case "x509_certs_provisioner":
assert.Equals(t, key, []byte("1234"))
assert.Equals(t, value, []byte(`{"id":"some-id","name":"admin","type":"JWK"}`))
default:
t.Errorf("unexpected bucket %s", bucket)
}
return nil
},
}, true}, args{p, chain}, false},
{"ok no provisioner", fields{&MockNoSQLDB{
MSet: func(bucket, key, value []byte) error {
switch string(bucket) {
case "x509_certs":
assert.Equals(t, key, []byte("1234"))
assert.Equals(t, value, []byte("the certificate"))
default:
t.Errorf("unexpected bucket %s", bucket)
}
return nil
},
}, true}, args{nil, chain}, false},
{"fail store certificate", fields{&MockNoSQLDB{
MSet: func(bucket, key, value []byte) error {
switch string(bucket) {
case "x509_certs":
return errors.New("test error")
default:
return nil
}
},
}, true}, args{p, chain}, true},
{"fail store provisioner", fields{&MockNoSQLDB{
MSet: func(bucket, key, value []byte) error {
switch string(bucket) {
case "x509_certs_provisioner":
return errors.New("test error")
default:
return nil
}
},
}, true}, args{p, chain}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := &DB{
DB: tt.fields.DB,
isUp: tt.fields.isUp,
}
if err := d.StoreCertificateChain(tt.args.p, tt.args.chain...); (err != nil) != tt.wantErr {
t.Errorf("DB.StoreCertificateChain() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}