diff --git a/api/api.go b/api/api.go index da6309fd..9b795cf0 100644 --- a/api/api.go +++ b/api/api.go @@ -52,6 +52,16 @@ type Authority interface { Version() authority.Version } +var errAuthority = errors.New("authority is not in context") + +func mustAuthority(ctx context.Context) Authority { + a, ok := authority.FromContext(ctx) + if !ok { + panic(errAuthority) + } + return a +} + // TimeDuration is an alias of provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration @@ -251,40 +261,40 @@ func New(auth Authority) RouterHandler { } func (h *caHandler) Route(r Router) { - r.MethodFunc("GET", "/version", h.Version) - r.MethodFunc("GET", "/health", h.Health) - r.MethodFunc("GET", "/root/{sha}", h.Root) - r.MethodFunc("POST", "/sign", h.Sign) - r.MethodFunc("POST", "/renew", h.Renew) - r.MethodFunc("POST", "/rekey", h.Rekey) - r.MethodFunc("POST", "/revoke", h.Revoke) - r.MethodFunc("GET", "/provisioners", h.Provisioners) - r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) - r.MethodFunc("GET", "/roots", h.Roots) - r.MethodFunc("GET", "/roots.pem", h.RootsPEM) - r.MethodFunc("GET", "/federation", h.Federation) + r.MethodFunc("GET", "/version", Version) + r.MethodFunc("GET", "/health", Health) + r.MethodFunc("GET", "/root/{sha}", Root) + r.MethodFunc("POST", "/sign", Sign) + r.MethodFunc("POST", "/renew", Renew) + r.MethodFunc("POST", "/rekey", Rekey) + r.MethodFunc("POST", "/revoke", Revoke) + r.MethodFunc("GET", "/provisioners", Provisioners) + r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey) + r.MethodFunc("GET", "/roots", Roots) + r.MethodFunc("GET", "/roots.pem", RootsPEM) + r.MethodFunc("GET", "/federation", Federation) // SSH CA - r.MethodFunc("POST", "/ssh/sign", h.SSHSign) - r.MethodFunc("POST", "/ssh/renew", h.SSHRenew) - r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke) - r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey) - r.MethodFunc("GET", "/ssh/roots", h.SSHRoots) - r.MethodFunc("GET", "/ssh/federation", h.SSHFederation) - r.MethodFunc("POST", "/ssh/config", h.SSHConfig) - r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) - r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) - r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts) - r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion) + r.MethodFunc("POST", "/ssh/sign", SSHSign) + r.MethodFunc("POST", "/ssh/renew", SSHRenew) + r.MethodFunc("POST", "/ssh/revoke", SSHRevoke) + r.MethodFunc("POST", "/ssh/rekey", SSHRekey) + r.MethodFunc("GET", "/ssh/roots", SSHRoots) + r.MethodFunc("GET", "/ssh/federation", SSHFederation) + r.MethodFunc("POST", "/ssh/config", SSHConfig) + r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig) + r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost) + r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts) + r.MethodFunc("POST", "/ssh/bastion", SSHBastion) // For compatibility with old code: - r.MethodFunc("POST", "/re-sign", h.Renew) - r.MethodFunc("POST", "/sign-ssh", h.SSHSign) - r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts) + r.MethodFunc("POST", "/re-sign", Renew) + r.MethodFunc("POST", "/sign-ssh", SSHSign) + r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts) } // Version is an HTTP handler that returns the version of the server. -func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { - v := h.Authority.Version() +func Version(w http.ResponseWriter, r *http.Request) { + v := mustAuthority(r.Context()).Version() render.JSON(w, VersionResponse{ Version: v.Version, RequireClientAuthentication: v.RequireClientAuthentication, @@ -292,17 +302,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { } // Health is an HTTP handler that returns the status of the server. -func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { +func Health(w http.ResponseWriter, r *http.Request) { render.JSON(w, HealthResponse{Status: "ok"}) } // Root is an HTTP handler that using the SHA256 from the URL, returns the root // certificate for the given SHA256. -func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { +func Root(w http.ResponseWriter, r *http.Request) { sha := chi.URLParam(r, "sha") sum := strings.ToLower(strings.ReplaceAll(sha, "-", "")) // Load root certificate with the - cert, err := h.Authority.Root(sum) + cert, err := mustAuthority(r.Context()).Root(sum) if err != nil { render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) return @@ -320,18 +330,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate { } // Provisioners returns the list of provisioners configured in the authority. -func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { +func Provisioners(w http.ResponseWriter, r *http.Request) { cursor, limit, err := ParseCursor(r) if err != nil { render.Error(w, err) return } - p, next, err := h.Authority.GetProvisioners(cursor, limit) + p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit) if err != nil { render.Error(w, errs.InternalServerErr(err)) return } + render.JSON(w, &ProvisionersResponse{ Provisioners: p, NextCursor: next, @@ -339,19 +350,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { } // ProvisionerKey returns the encrypted key of a provisioner by it's key id. -func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { +func ProvisionerKey(w http.ResponseWriter, r *http.Request) { kid := chi.URLParam(r, "kid") - key, err := h.Authority.GetEncryptedKey(kid) + key, err := mustAuthority(r.Context()).GetEncryptedKey(kid) if err != nil { render.Error(w, errs.NotFoundErr(err)) return } + render.JSON(w, &ProvisionerKeyResponse{key}) } // Roots returns all the root certificates for the CA. -func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { - roots, err := h.Authority.GetRoots() +func Roots(w http.ResponseWriter, r *http.Request) { + roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, errs.ForbiddenErr(err, "error getting roots")) return @@ -368,8 +380,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { } // RootsPEM returns all the root certificates for the CA in PEM format. -func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { - roots, err := h.Authority.GetRoots() +func RootsPEM(w http.ResponseWriter, r *http.Request) { + roots, err := mustAuthority(r.Context()).GetRoots() if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -391,8 +403,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { } // Federation returns all the public certificates in the federation. -func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { - federated, err := h.Authority.GetFederation() +func Federation(w http.ResponseWriter, r *http.Request) { + federated, err := mustAuthority(r.Context()).GetFederation() if err != nil { render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) return diff --git a/api/rekey.go b/api/rekey.go index 3116cf74..cda843a3 100644 --- a/api/rekey.go +++ b/api/rekey.go @@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error { } // Rekey is similar to renew except that the certificate will be renewed with new key from csr. -func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { +func Rekey(w http.ResponseWriter, r *http.Request) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { render.Error(w, errs.BadRequest("missing client certificate")) return @@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { return } - certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) + a := mustAuthority(r.Context()) + certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) return @@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } diff --git a/api/renew.go b/api/renew.go index 9c4bff32..6e9f680f 100644 --- a/api/renew.go +++ b/api/renew.go @@ -16,14 +16,15 @@ const ( // Renew uses the information of certificate in the TLS connection to create a // new one. -func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { - cert, err := h.getPeerCertificate(r) +func Renew(w http.ResponseWriter, r *http.Request) { + cert, err := getPeerCertificate(r) if err != nil { render.Error(w, err) return } - certChain, err := h.Authority.Renew(cert) + a := mustAuthority(r.Context()) + certChain, err := a.Renew(cert) if err != nil { render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) return @@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } -func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { +func getPeerCertificate(r *http.Request) (*x509.Certificate, error) { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { return r.TLS.PeerCertificates[0], nil } if s := r.Header.Get(authorizationHeader); s != "" { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { - return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) + ctx := r.Context() + return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1]) } } return nil, errs.BadRequest("missing client certificate") diff --git a/api/revoke.go b/api/revoke.go index c9da2c18..aebbb875 100644 --- a/api/revoke.go +++ b/api/revoke.go @@ -1,7 +1,6 @@ package api import ( - "context" "net/http" "golang.org/x/crypto/ocsp" @@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) { // NOTE: currently only Passive revocation is supported. // // TODO: Add CRL and OCSP support. -func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { +func Revoke(w http.ResponseWriter, r *http.Request) { var body RevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { PassiveOnly: body.Passive, } - ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) + ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod) + a := mustAuthority(ctx) + // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. if len(body.OTT) > 0 { logOtt(w, body.OTT) - if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { + if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } @@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { opts.MTLS = true } - if err := h.Authority.Revoke(ctx, opts); err != nil { + if err := a.Revoke(ctx, opts); err != nil { render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) return } diff --git a/api/sign.go b/api/sign.go index b6bfcc8b..b263e2e9 100644 --- a/api/sign.go +++ b/api/sign.go @@ -49,7 +49,7 @@ type SignResponse struct { // Sign is an HTTP handler that reads a certificate request and an // one-time-token (ott) from the body and creates a new certificate with the // information in the certificate request. -func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { +func Sign(w http.ResponseWriter, r *http.Request) { var body SignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,13 +68,14 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { TemplateData: body.TemplateData, } - signOpts, err := h.Authority.AuthorizeSign(body.OTT) + a := mustAuthority(r.Context()) + signOpts, err := a.AuthorizeSign(body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } - certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) + certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) return @@ -89,6 +90,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { ServerPEM: certChainPEM[0], CaPEM: caPEM, CertChainPEM: certChainPEM, - TLSOptions: h.Authority.GetTLSOptions(), + TLSOptions: a.GetTLSOptions(), }, http.StatusCreated) } diff --git a/api/ssh.go b/api/ssh.go index 3b0de7c1..f3056fc5 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -250,7 +250,7 @@ type SSHBastionResponse struct { // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { +func SSHSign(w http.ResponseWriter, r *http.Request) { var body SSHSignRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -288,13 +288,15 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + + a := mustAuthority(ctx) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } - cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) + cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return @@ -302,7 +304,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { var addUserCertificate *SSHCertificate if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { - addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) + addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) return @@ -315,7 +317,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { if cr := body.IdentityCSR.CertificateRequest; cr != nil { ctx := authority.NewContextWithSkipTokenReuse(r.Context()) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -327,7 +329,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { NotAfter: time.Unix(int64(cert.ValidBefore), 0), }) - certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) + certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) return @@ -344,8 +346,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { // SSHRoots is an HTTP handler that returns the SSH public keys for user and host // certificates. -func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHRoots(r.Context()) +func SSHRoots(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + keys, err := mustAuthority(ctx).GetSSHRoots(ctx) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -369,8 +372,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { // SSHFederation is an HTTP handler that returns the federated SSH public keys // for user and host certificates. -func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { - keys, err := h.Authority.GetSSHFederation(r.Context()) +func SSHFederation(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + keys, err := mustAuthority(ctx).GetSSHFederation(ctx) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -394,7 +398,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { // SSHConfig is an HTTP handler that returns rendered templates for ssh clients // and servers. -func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { +func SSHConfig(w http.ResponseWriter, r *http.Request) { var body SSHConfigRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -405,7 +409,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { return } - ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) + ctx := r.Context() + ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -426,7 +431,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { } // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. -func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { +func SSHCheckHost(w http.ResponseWriter, r *http.Request) { var body SSHCheckPrincipalRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -437,7 +442,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { return } - exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) + ctx := r.Context() + exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -448,13 +454,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { } // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. -func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { +func SSHGetHosts(w http.ResponseWriter, r *http.Request) { var cert *x509.Certificate if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { cert = r.TLS.PeerCertificates[0] } - hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) + ctx := r.Context() + hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert) if err != nil { render.Error(w, errs.InternalServerErr(err)) return @@ -465,7 +472,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { } // SSHBastion provides returns the bastion configured if any. -func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { +func SSHBastion(w http.ResponseWriter, r *http.Request) { var body SSHBastionRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -476,7 +483,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { return } - bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) + ctx := r.Context() + bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname) if err != nil { render.Error(w, errs.InternalServerErr(err)) return diff --git a/api/sshRekey.go b/api/sshRekey.go index 92278950..184f208a 100644 --- a/api/sshRekey.go +++ b/api/sshRekey.go @@ -39,7 +39,7 @@ type SSHRekeyResponse struct { // SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { +func SSHRekey(w http.ResponseWriter, r *http.Request) { var body SSHRekeyRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -59,7 +59,9 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) - signOpts, err := h.Authority.Authorize(ctx, body.OTT) + + a := mustAuthority(ctx) + signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -70,7 +72,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { return } - newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) + newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) return @@ -80,7 +82,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0) - identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) + identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return diff --git a/api/sshRenew.go b/api/sshRenew.go index 78d16fa6..606b45bb 100644 --- a/api/sshRenew.go +++ b/api/sshRenew.go @@ -37,7 +37,7 @@ type SSHRenewResponse struct { // SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token // (ott) from the body and creates a new SSH certificate with the information in // the request. -func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { +func SSHRenew(w http.ResponseWriter, r *http.Request) { var body SSHRenewRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -51,7 +51,8 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) - _, err := h.Authority.Authorize(ctx, body.OTT) + a := mustAuthority(ctx) + _, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) return @@ -62,7 +63,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { return } - newCert, err := h.Authority.RenewSSH(ctx, oldCert) + newCert, err := a.RenewSSH(ctx, oldCert) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) return @@ -72,7 +73,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0) - identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) + identity, err := renewIdentityCertificate(r, notBefore, notAfter) if err != nil { render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) return @@ -85,7 +86,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { } // renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the -func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { +func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { return nil, nil } @@ -105,7 +106,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte cert.NotAfter = notAfter } - certChain, err := h.Authority.Renew(cert) + certChain, err := mustAuthority(r.Context()).Renew(cert) if err != nil { return nil, err } diff --git a/api/sshRevoke.go b/api/sshRevoke.go index a33082cd..d377def9 100644 --- a/api/sshRevoke.go +++ b/api/sshRevoke.go @@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) { // Revoke supports handful of different methods that revoke a Certificate. // // NOTE: currently only Passive revocation is supported. -func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { +func SSHRevoke(w http.ResponseWriter, r *http.Request) { var body SSHRevokeRequest if err := read.JSON(r.Body, &body); err != nil { render.Error(w, errs.BadRequestErr(err, "error reading request body")) @@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { } ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) + a := mustAuthority(ctx) + // A token indicates that we are using the api via a provisioner token, // otherwise it is assumed that the certificate is revoking itself over mTLS. logOtt(w, body.OTT) - if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { + + if _, err := a.Authorize(ctx, body.OTT); err != nil { render.Error(w, errs.UnauthorizedErr(err)) return } opts.OTT = body.OTT - if err := h.Authority.Revoke(ctx, opts); err != nil { + if err := a.Revoke(ctx, opts); err != nil { render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) return }