diff --git a/authority/db_test.go b/authority/db_test.go index bd6b27ca..72684c63 100644 --- a/authority/db_test.go +++ b/authority/db_test.go @@ -8,26 +8,18 @@ import ( ) type MockAuthDB struct { - err error - ret1 interface{} - init func(*db.Config) (db.AuthDB, error) - isRevoked func(string) (bool, error) - revoke func(rci *db.RevokedCertificateInfo) error - storeCertificate func(crt *x509.Certificate) error - useToken func(id, tok string) (bool, error) - isSSHHost func(principal string) (bool, error) - storeSSHCertificate func(crt *ssh.Certificate) error - shutdown func() error -} - -func (m *MockAuthDB) Init(c *db.Config) (db.AuthDB, error) { - if m.init != nil { - return m.init(c) - } - if m.ret1 == nil { - return nil, m.err - } - return m.ret1.(*db.DB), m.err + err error + ret1 interface{} + isRevoked func(string) (bool, error) + isSSHRevoked func(string) (bool, error) + revoke func(rci *db.RevokedCertificateInfo) error + revokeSSH func(rci *db.RevokedCertificateInfo) error + storeCertificate func(crt *x509.Certificate) error + useToken func(id, tok string) (bool, error) + isSSHHost func(principal string) (bool, error) + storeSSHCertificate func(crt *ssh.Certificate) error + getSSHHostPrincipals func() ([]string, error) + shutdown func() error } func (m *MockAuthDB) IsRevoked(sn string) (bool, error) { @@ -37,6 +29,13 @@ func (m *MockAuthDB) IsRevoked(sn string) (bool, error) { return m.ret1.(bool), m.err } +func (m *MockAuthDB) IsSSHRevoked(sn string) (bool, error) { + if m.isSSHRevoked != nil { + return m.isSSHRevoked(sn) + } + return m.ret1.(bool), m.err +} + func (m *MockAuthDB) UseToken(id, tok string) (bool, error) { if m.useToken != nil { return m.useToken(id, tok) @@ -54,6 +53,13 @@ func (m *MockAuthDB) Revoke(rci *db.RevokedCertificateInfo) error { return m.err } +func (m *MockAuthDB) RevokeSSH(rci *db.RevokedCertificateInfo) error { + if m.revokeSSH != nil { + return m.revokeSSH(rci) + } + return m.err +} + func (m *MockAuthDB) StoreCertificate(crt *x509.Certificate) error { if m.storeCertificate != nil { return m.storeCertificate(crt) @@ -75,6 +81,13 @@ func (m *MockAuthDB) StoreSSHCertificate(crt *ssh.Certificate) error { return m.err } +func (m *MockAuthDB) GetSSHHostPrincipals() ([]string, error) { + if m.getSSHHostPrincipals != nil { + return m.getSSHHostPrincipals() + } + return m.ret1.([]string), m.err +} + func (m *MockAuthDB) Shutdown() error { if m.shutdown != nil { return m.shutdown()