forked from TrueCloudLab/certificates
Add SSH getHosts api
This commit is contained in:
parent
5092e8cfc2
commit
64b69374fa
6 changed files with 98 additions and 12 deletions
|
@ -257,6 +257,7 @@ func (h *caHandler) Route(r Router) {
|
|||
r.MethodFunc("POST", "/ssh/config", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig)
|
||||
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost)
|
||||
r.MethodFunc("POST", "/ssh/get-hosts", h.SSHGetHosts)
|
||||
|
||||
// For compatibility with old code:
|
||||
r.MethodFunc("POST", "/re-sign", h.Renew)
|
||||
|
|
18
api/ssh.go
18
api/ssh.go
|
@ -21,6 +21,7 @@ type SSHAuthority interface {
|
|||
GetSSHFederation() (*authority.SSHKeys, error)
|
||||
GetSSHConfig(typ string, data map[string]string) ([]templates.Output, error)
|
||||
CheckSSHHost(principal string) (bool, error)
|
||||
GetSSHHosts() ([]string, error)
|
||||
}
|
||||
|
||||
// SSHSignRequest is the request body of an SSH certificate request.
|
||||
|
@ -66,6 +67,11 @@ type SSHCertificate struct {
|
|||
*ssh.Certificate `json:"omitempty"`
|
||||
}
|
||||
|
||||
// SSHGetHostsResponse
|
||||
type SSHGetHostsResponse struct {
|
||||
Hosts []string `json:"hosts"`
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface. Returns a quoted,
|
||||
// base64 encoded, openssh wire format version of the certificate.
|
||||
func (c SSHCertificate) MarshalJSON() ([]byte, error) {
|
||||
|
@ -369,3 +375,15 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
|||
Exists: exists,
|
||||
})
|
||||
}
|
||||
|
||||
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
|
||||
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
||||
hosts, err := h.Authority.GetSSHHosts()
|
||||
if err != nil {
|
||||
WriteError(w, InternalServerError(err))
|
||||
return
|
||||
}
|
||||
JSON(w, &SSHGetHostsResponse{
|
||||
Hosts: hosts,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -369,6 +369,16 @@ func (a *Authority) CheckSSHHost(principal string) (bool, error) {
|
|||
return exists, nil
|
||||
}
|
||||
|
||||
// GetSSHHosts returns a list of valid host principals.
|
||||
func (a *Authority) GetSSHHosts() ([]string, error) {
|
||||
ps, err := a.db.GetSSHHostPrincipals()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
func (a *Authority) getAddUserPrincipal() (cmd string) {
|
||||
if a.config.SSH.AddUserPrincipal == "" {
|
||||
return SSHAddUserPrincipal
|
||||
|
|
18
ca/client.go
18
ca/client.go
|
@ -611,6 +611,24 @@ func (c *Client) SSHCheckHost(principal string) (*api.SSHCheckPrincipalResponse,
|
|||
return &check, nil
|
||||
}
|
||||
|
||||
// SSHGetHostPrincipals performs the POST /ssh/check-host request to the CA with the
|
||||
// given principal.
|
||||
func (c *Client) SSHGetHostPrincipals() (*api.SSHGetHostsResponse, error) {
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/get-hosts"})
|
||||
resp, err := c.client.Get(u.String())
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "client GET %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, readError(resp.Body)
|
||||
}
|
||||
var hosts api.SSHGetHostsResponse
|
||||
if err := readJSON(resp.Body, &hosts); err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||
}
|
||||
return &hosts, nil
|
||||
}
|
||||
|
||||
// RootFingerprint is a helper method that returns the current root fingerprint.
|
||||
// It does an health connection and gets the fingerprint from the TLS verified
|
||||
// chains.
|
||||
|
|
58
db/db.go
58
db/db.go
|
@ -14,12 +14,13 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
certsTable = []byte("x509_certs")
|
||||
revokedCertsTable = []byte("revoked_x509_certs")
|
||||
usedOTTTable = []byte("used_ott")
|
||||
sshCertsTable = []byte("ssh_certs")
|
||||
sshHostsTable = []byte("ssh_hosts")
|
||||
sshUsersTable = []byte("ssh_users")
|
||||
certsTable = []byte("x509_certs")
|
||||
revokedCertsTable = []byte("revoked_x509_certs")
|
||||
usedOTTTable = []byte("used_ott")
|
||||
sshCertsTable = []byte("ssh_certs")
|
||||
sshHostsTable = []byte("ssh_hosts")
|
||||
sshUsersTable = []byte("ssh_users")
|
||||
sshHostPrincipalsTable = []byte("ssh_host_principals")
|
||||
)
|
||||
|
||||
// ErrAlreadyExists can be returned if the DB attempts to set a key that has
|
||||
|
@ -42,6 +43,7 @@ type AuthDB interface {
|
|||
UseToken(id, tok string) (bool, error)
|
||||
IsSSHHost(name string) (bool, error)
|
||||
StoreSSHCertificate(crt *ssh.Certificate) error
|
||||
GetSSHHostPrincipals() ([]string, error)
|
||||
Shutdown() error
|
||||
}
|
||||
|
||||
|
@ -160,19 +162,32 @@ func (db *DB) IsSSHHost(principal string) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
type sshHostPrincipalData struct {
|
||||
Serial string
|
||||
Expiry uint64
|
||||
}
|
||||
|
||||
// StoreSSHCertificate stores an SSH certificate.
|
||||
func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
||||
var table []byte
|
||||
serial := strconv.FormatUint(crt.Serial, 10)
|
||||
tx := new(database.Tx)
|
||||
tx.Set(sshCertsTable, []byte(serial), crt.Marshal())
|
||||
if crt.CertType == ssh.HostCert {
|
||||
table = sshHostsTable
|
||||
for _, p := range crt.ValidPrincipals {
|
||||
hostPrincipalData, err := json.Marshal(sshHostPrincipalData{
|
||||
Serial: serial,
|
||||
Expiry: crt.ValidBefore,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.Set(sshHostsTable, []byte(strings.ToLower(p)), []byte(serial))
|
||||
tx.Set(sshHostPrincipalsTable, []byte(strings.ToLower(p)), hostPrincipalData)
|
||||
}
|
||||
} else {
|
||||
table = sshUsersTable
|
||||
}
|
||||
for _, p := range crt.ValidPrincipals {
|
||||
tx.Set(table, []byte(strings.ToLower(p)), []byte(serial))
|
||||
for _, p := range crt.ValidPrincipals {
|
||||
tx.Set(sshUsersTable, []byte(strings.ToLower(p)), []byte(serial))
|
||||
}
|
||||
}
|
||||
if err := db.Update(tx); err != nil {
|
||||
return errors.Wrap(err, "database Update error")
|
||||
|
@ -181,6 +196,25 @@ func (db *DB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
|||
|
||||
}
|
||||
|
||||
// GetSSHHostPrincipals gets a list of all valid host principals.
|
||||
func (db *DB) GetSSHHostPrincipals() ([]string, error) {
|
||||
entries, err := db.List(sshHostPrincipalsTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var principals []string
|
||||
for _, e := range entries {
|
||||
var data sshHostPrincipalData
|
||||
if err := json.Unmarshal(e.Value, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if time.Unix(int64(data.Expiry), 0).After(time.Now()) {
|
||||
principals = append(principals, string(e.Key))
|
||||
}
|
||||
}
|
||||
return principals, nil
|
||||
}
|
||||
|
||||
// Shutdown sends a shutdown message to the database.
|
||||
func (db *DB) Shutdown() error {
|
||||
if db.isUp {
|
||||
|
|
|
@ -69,6 +69,11 @@ func (s *SimpleDB) StoreSSHCertificate(crt *ssh.Certificate) error {
|
|||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
// GetSSHHostPrincipals returns a "NotImplemented" error.
|
||||
func (s *SimpleDB) GetSSHHostPrincipals() ([]string, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// Shutdown returns nil
|
||||
func (s *SimpleDB) Shutdown() error {
|
||||
return nil
|
||||
|
|
Loading…
Reference in a new issue