certificates/db/db.go

374 lines
9.8 KiB
Go

package db
import (
"crypto/x509"
"encoding/json"
"strconv"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
"golang.org/x/crypto/ssh"
)
// Names of the nosql tables (buckets) this package reads and writes. New
// creates every one of them when it opens the database.
var (
// x509 certificates, keyed by serial number (see StoreCertificate).
certsTable = []byte("x509_certs")
// Revocation records for x509 certificates, keyed by serial (see Revoke).
revokedCertsTable = []byte("revoked_x509_certs")
// Revocation records for SSH certificates, keyed by serial (see RevokeSSH).
revokedSSHCertsTable = []byte("revoked_ssh_certs")
// One-time tokens that were already used, keyed by token id (see UseToken).
usedOTTTable = []byte("used_ott")
// Marshaled SSH certificates, keyed by decimal serial (see StoreSSHCertificate).
sshCertsTable = []byte("ssh_certs")
// Lower-cased host principal -> serial of the stored host certificate.
sshHostsTable = []byte("ssh_hosts")
// Lower-cased principal -> serial, written for non-host SSH certificates.
sshUsersTable = []byte("ssh_users")
// Lower-cased host principal -> JSON-encoded sshHostPrincipalData.
sshHostPrincipalsTable = []byte("ssh_host_principals")
)
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
// been previously set. In this file Revoke and RevokeSSH return it when
// CmpAndSwap reports that the new record was not swapped in.
var ErrAlreadyExists = errors.New("already exists")
// Config represents the JSON attributes used for configuring a step-ca DB.
// New forwards these values to nosql.New.
type Config struct {
// Type is passed to nosql.New as the database type.
Type string `json:"type"`
// DataSource is passed to nosql.New as the data source.
DataSource string `json:"dataSource"`
// ValueDir is forwarded through nosql.WithValueDir.
ValueDir string `json:"valueDir,omitempty"`
// Database is forwarded through nosql.WithDatabase.
Database string `json:"database,omitempty"`
}
// AuthDB is the persistence interface returned by New. *DB implements it on
// top of a nosql.DB; the per-method notes below describe what *DB does.
type AuthDB interface {
// IsRevoked reports whether an x509 certificate serial has a revocation record.
IsRevoked(sn string) (bool, error)
// IsSSHRevoked reports whether an SSH certificate serial has a revocation record.
IsSSHRevoked(sn string) (bool, error)
// Revoke records the revocation of an x509 certificate.
Revoke(rci *RevokedCertificateInfo) error
// RevokeSSH records the revocation of an SSH certificate.
RevokeSSH(rci *RevokedCertificateInfo) error
// StoreCertificate persists an x509 certificate.
StoreCertificate(crt *x509.Certificate) error
// UseToken marks a token id as used; true means this was its first use.
UseToken(id, tok string) (bool, error)
// IsSSHHost reports whether name is a known SSH host principal.
IsSSHHost(name string) (bool, error)
// StoreSSHCertificate persists an SSH certificate and indexes its principals.
StoreSSHCertificate(crt *ssh.Certificate) error
// GetSSHHostPrincipals lists host principals whose certificate has not expired.
GetSSHHostPrincipals() ([]string, error)
// Shutdown closes the underlying database.
Shutdown() error
}
// DB is a wrapper over the nosql.DB interface.
type DB struct {
// Embedded store; its methods (Get, Set, CmpAndSwap, List, Update, Close)
// are called directly on *DB by the methods in this file.
nosql.DB
// isUp is set to true by New and cleared by Shutdown after Close succeeds,
// so a later Shutdown does not call Close again.
isUp bool
}
// New returns a new database client that implements the AuthDB interface.
//
// With a nil config it returns whatever newSimpleDB produces. Otherwise it
// opens a nosql database from the config and creates every table this package
// uses. If a table cannot be created the freshly opened database is closed
// again so the handle is not leaked, and the creation error is returned.
func New(c *Config) (AuthDB, error) {
	if c == nil {
		return newSimpleDB(c)
	}

	db, err := nosql.New(c.Type, c.DataSource, nosql.WithDatabase(c.Database),
		nosql.WithValueDir(c.ValueDir))
	if err != nil {
		return nil, errors.Wrapf(err, "Error opening database of Type %s with source %s", c.Type, c.DataSource)
	}

	tables := [][]byte{
		revokedCertsTable, certsTable, usedOTTTable,
		sshCertsTable, sshHostsTable, sshHostPrincipalsTable, sshUsersTable,
		revokedSSHCertsTable,
	}
	for _, b := range tables {
		if err := db.CreateTable(b); err != nil {
			// Best-effort cleanup: the caller never receives the handle, so
			// nobody else can close it. The CreateTable error is the one
			// worth reporting, hence the Close error is deliberately dropped.
			_ = db.Close()
			return nil, errors.Wrapf(err, "error creating table %s", string(b))
		}
	}
	return &DB{db, true}, nil
}
// RevokedCertificateInfo contains information regarding the certificate
// revocation action. Revoke and RevokeSSH store it JSON-encoded, keyed by
// Serial.
type RevokedCertificateInfo struct {
// Serial identifies the revoked certificate and is the key of the stored record.
Serial string
// The remaining fields are not interpreted by this package; they are
// persisted exactly as provided by the caller.
ProvisionerID string
ReasonCode int
Reason string
RevokedAt time.Time
TokenID string
MTLS bool
}
// IsRevoked reports whether the x509 certificate with serial number sn has an
// entry in the x509 revocation table.
//
// A nil *DB acts as a pass-through and reports "not revoked". A lookup that
// fails with a not-found error also means "not revoked"; any other lookup
// error is wrapped and returned.
func (db *DB) IsRevoked(sn string) (bool, error) {
	if db == nil {
		return false, nil
	}

	_, err := db.Get(revokedCertsTable, []byte(sn))
	switch {
	case err == nil:
		// A record exists, so the certificate has been revoked.
		return true, nil
	case nosql.IsErrNotFound(err):
		return false, nil
	default:
		return false, errors.Wrap(err, "error checking revocation bucket")
	}
}
// IsSSHRevoked reports whether the SSH certificate with serial number sn has
// an entry in the SSH revocation table.
//
// A nil *DB acts as a pass-through and reports "not revoked". A lookup that
// fails with a not-found error also means "not revoked"; any other lookup
// error is wrapped and returned.
func (db *DB) IsSSHRevoked(sn string) (bool, error) {
	if db == nil {
		return false, nil
	}

	_, err := db.Get(revokedSSHCertsTable, []byte(sn))
	switch {
	case err == nil:
		// A record exists, so the certificate has been revoked.
		return true, nil
	case nosql.IsErrNotFound(err):
		return false, nil
	default:
		return false, errors.Wrap(err, "error checking revocation bucket")
	}
}
// Revoke adds an x509 certificate to the revocation table.
//
// The info is stored JSON-encoded under rci.Serial using a compare-and-swap
// against a nil previous value. ErrAlreadyExists is returned when the swap
// does not happen.
func (db *DB) Revoke(rci *RevokedCertificateInfo) error {
	record, err := json.Marshal(rci)
	if err != nil {
		return errors.Wrap(err, "error marshaling revoked certificate info")
	}

	_, stored, err := db.CmpAndSwap(revokedCertsTable, []byte(rci.Serial), nil, record)
	if err != nil {
		return errors.Wrap(err, "error AuthDB CmpAndSwap")
	}
	if !stored {
		return ErrAlreadyExists
	}
	return nil
}
// RevokeSSH adds an SSH certificate to the SSH revocation table.
//
// The info is stored JSON-encoded under rci.Serial using a compare-and-swap
// against a nil previous value. ErrAlreadyExists is returned when the swap
// does not happen.
func (db *DB) RevokeSSH(rci *RevokedCertificateInfo) error {
	record, err := json.Marshal(rci)
	if err != nil {
		return errors.Wrap(err, "error marshaling revoked certificate info")
	}

	_, stored, err := db.CmpAndSwap(revokedSSHCertsTable, []byte(rci.Serial), nil, record)
	if err != nil {
		return errors.Wrap(err, "error AuthDB CmpAndSwap")
	}
	if !stored {
		return ErrAlreadyExists
	}
	return nil
}
// StoreCertificate stores the raw bytes of an x509 certificate (crt.Raw)
// in the certificates table, keyed by the decimal string of its serial number.
func (db *DB) StoreCertificate(crt *x509.Certificate) error {
	key := []byte(crt.SerialNumber.String())
	err := db.Set(certsTable, key, crt.Raw)
	if err != nil {
		return errors.Wrap(err, "database Set error")
	}
	return nil
}
// UseToken returns true if the token was stored successfully for the first
// time, false otherwise. The token is written under id with a
// compare-and-swap against a nil previous value.
func (db *DB) UseToken(id, tok string) (bool, error) {
	_, firstUse, err := db.CmpAndSwap(usedOTTTable, []byte(id), nil, []byte(tok))
	if err == nil {
		return firstUse, nil
	}
	return false, errors.Wrapf(err, "error storing used token %s/%s",
		string(usedOTTTable), id)
}
// IsSSHHost reports whether principal is present in the ssh hosts table.
// The lookup key is the lower-cased principal. A not-found error means
// "not a host"; any other lookup error is wrapped and returned.
func (db *DB) IsSSHHost(principal string) (bool, error) {
	key := []byte(strings.ToLower(principal))
	_, err := db.Get(sshHostsTable, key)
	switch {
	case err == nil:
		return true, nil
	case database.IsErrNotFound(err):
		return false, nil
	default:
		return false, errors.Wrap(err, "database Get error")
	}
}
// sshHostPrincipalData is the JSON value StoreSSHCertificate writes to
// sshHostPrincipalsTable for each principal of a host certificate.
type sshHostPrincipalData struct {
// Serial is the decimal serial number of the host certificate.
Serial string
// Expiry is the certificate's ValidBefore value; GetSSHHostPrincipals
// interprets it as Unix seconds.
Expiry uint64
}
// StoreSSHCertificate stores an SSH certificate and indexes its principals.
//
// All writes are collected in a single database.Tx and submitted with one
// Update call:
//   - the marshaled certificate goes to sshCertsTable under its decimal serial;
//   - for host certificates, every lower-cased principal is mapped to the
//     serial in sshHostsTable and to a JSON sshHostPrincipalData record
//     (serial and ValidBefore) in sshHostPrincipalsTable;
//   - for any other certificate type, every lower-cased principal is mapped
//     to the serial in sshUsersTable.
func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
	serial := strconv.FormatUint(crt.Serial, 10)
	tx := new(database.Tx)
	tx.Set(sshCertsTable, []byte(serial), crt.Marshal())

	if crt.CertType == ssh.HostCert {
		// The record depends only on the certificate, not on the principal,
		// so it is encoded once and reused for every principal.
		hostPrincipalData, err := json.Marshal(sshHostPrincipalData{
			Serial: serial,
			Expiry: crt.ValidBefore,
		})
		if err != nil {
			return errors.Wrap(err, "error marshaling ssh host principal data")
		}
		for _, p := range crt.ValidPrincipals {
			name := strings.ToLower(p)
			tx.Set(sshHostsTable, []byte(name), []byte(serial))
			tx.Set(sshHostPrincipalsTable, []byte(name), hostPrincipalData)
		}
	} else {
		for _, p := range crt.ValidPrincipals {
			tx.Set(sshUsersTable, []byte(strings.ToLower(p)), []byte(serial))
		}
	}

	if err := db.Update(tx); err != nil {
		return errors.Wrap(err, "database Update error")
	}
	return nil
}
// GetSSHHostPrincipals gets a list of all valid host principals.
//
// It lists sshHostPrincipalsTable, decodes each value as sshHostPrincipalData
// and keeps the keys whose Expiry is still in the future. An Expiry equal to
// ssh.CertTimeInfinity marks a certificate that never expires and is always
// kept; without that check the conversion to int64 would turn it into -1 and
// the principal would wrongly be treated as expired.
//
// Errors from List and from JSON decoding are returned as-is.
func (db *DB) GetSSHHostPrincipals() ([]string, error) {
	entries, err := db.List(sshHostPrincipalsTable)
	if err != nil {
		return nil, err
	}

	// One timestamp for the whole scan so every entry is judged against the
	// same instant.
	now := time.Now()

	var principals []string
	for _, e := range entries {
		var data sshHostPrincipalData
		if err := json.Unmarshal(e.Value, &data); err != nil {
			return nil, err
		}
		if data.Expiry == ssh.CertTimeInfinity || time.Unix(int64(data.Expiry), 0).After(now) {
			principals = append(principals, string(e.Key))
		}
	}
	return principals, nil
}
// Shutdown closes the underlying database if it is still marked as up.
// After a successful Close the DB is marked as down, so further calls are
// no-ops; a failed Close is wrapped and returned, and the DB stays marked up.
func (db *DB) Shutdown() error {
	if !db.isUp {
		return nil
	}
	if err := db.Close(); err != nil {
		return errors.Wrap(err, "database shutdown error")
	}
	db.isUp = false
	return nil
}
// MockNoSQLDB is a configurable test double whose methods mirror the
// database operations used in this package (presumably so it can stand in
// for a nosql.DB — confirm against that interface).
type MockNoSQLDB struct {
// Err is returned by every method whose M* hook is nil.
Err error
// Ret1 and Ret2 are canned results for the hook-less path: Get and List
// read Ret1; CmpAndSwap reads Ret1 as the value and Ret2 as the swapped flag.
Ret1, Ret2 interface{}
// M* hooks: when one is non-nil the matching method delegates to it and
// ignores Err, Ret1 and Ret2.
MGet func(bucket, key []byte) ([]byte, error)
MSet func(bucket, key, value []byte) error
MOpen func(dataSourceName string, opt ...database.Option) error
MClose func() error
MCreateTable func(bucket []byte) error
MDeleteTable func(bucket []byte) error
MDel func(bucket, key []byte) error
MList func(bucket []byte) ([]*database.Entry, error)
MUpdate func(tx *database.Tx) error
MCmpAndSwap func(bucket, key, old, newval []byte) ([]byte, bool, error)
}
// CmpAndSwap delegates to MCmpAndSwap when that hook is set. Without a hook
// it returns (nil, false, Err) if Ret1 is nil, and otherwise Ret1 as []byte,
// Ret2 as bool, and Err.
func (m *MockNoSQLDB) CmpAndSwap(bucket, key, old, newval []byte) ([]byte, bool, error) {
	switch {
	case m.MCmpAndSwap != nil:
		return m.MCmpAndSwap(bucket, key, old, newval)
	case m.Ret1 == nil:
		return nil, false, m.Err
	default:
		return m.Ret1.([]byte), m.Ret2.(bool), m.Err
	}
}
// Get delegates to MGet when that hook is set. Without a hook it returns
// (nil, Err) if Ret1 is nil, and otherwise Ret1 as []byte together with Err.
func (m *MockNoSQLDB) Get(bucket, key []byte) ([]byte, error) {
	switch {
	case m.MGet != nil:
		return m.MGet(bucket, key)
	case m.Ret1 == nil:
		return nil, m.Err
	default:
		return m.Ret1.([]byte), m.Err
	}
}
// Set delegates to MSet when that hook is set; otherwise it returns Err.
func (m *MockNoSQLDB) Set(bucket, key, value []byte) error {
	if m.MSet == nil {
		return m.Err
	}
	return m.MSet(bucket, key, value)
}
// Open delegates to MOpen when that hook is set; otherwise it returns Err.
func (m *MockNoSQLDB) Open(dataSourceName string, opt ...database.Option) error {
	if m.MOpen == nil {
		return m.Err
	}
	return m.MOpen(dataSourceName, opt...)
}
// Close delegates to MClose when that hook is set; otherwise it returns Err.
func (m *MockNoSQLDB) Close() error {
	if m.MClose == nil {
		return m.Err
	}
	return m.MClose()
}
// CreateTable delegates to MCreateTable when that hook is set; otherwise it
// returns Err.
func (m *MockNoSQLDB) CreateTable(bucket []byte) error {
	if m.MCreateTable == nil {
		return m.Err
	}
	return m.MCreateTable(bucket)
}
// DeleteTable delegates to MDeleteTable when that hook is set; otherwise it
// returns Err.
func (m *MockNoSQLDB) DeleteTable(bucket []byte) error {
	if m.MDeleteTable == nil {
		return m.Err
	}
	return m.MDeleteTable(bucket)
}
// Del delegates to MDel when that hook is set; otherwise it returns Err.
func (m *MockNoSQLDB) Del(bucket, key []byte) error {
	if m.MDel == nil {
		return m.Err
	}
	return m.MDel(bucket, key)
}
// List delegates to MList when that hook is set. Without a hook it returns
// (nil, Err) if Ret1 is nil, and otherwise Ret1 as []*database.Entry together
// with Err.
//
// The nil check mirrors Get and CmpAndSwap: a mock configured with only Err
// reports that error instead of panicking on the type assertion.
func (m *MockNoSQLDB) List(bucket []byte) ([]*database.Entry, error) {
	if m.MList != nil {
		return m.MList(bucket)
	}
	if m.Ret1 == nil {
		return nil, m.Err
	}
	return m.Ret1.([]*database.Entry), m.Err
}
// Update delegates to MUpdate when that hook is set; otherwise it returns Err.
func (m *MockNoSQLDB) Update(tx *database.Tx) error {
	if m.MUpdate == nil {
		return m.Err
	}
	return m.MUpdate(tx)
}