[acme db interface] unit test progress

This commit is contained in:
max furman 2021-03-10 23:05:46 -08:00
parent bb8d54e596
commit f71e27e787
7 changed files with 47 additions and 61 deletions

View file

@ -40,7 +40,7 @@ func newProv() provisioner.Interface {
return p return p
} }
func TestNewAccountRequestValidate(t *testing.T) { func TestNewAccountRequest_Validate(t *testing.T) {
type test struct { type test struct {
nar *NewAccountRequest nar *NewAccountRequest
err *acme.Error err *acme.Error
@ -96,7 +96,7 @@ func TestNewAccountRequestValidate(t *testing.T) {
} }
} }
func TestUpdateAccountRequestValidate(t *testing.T) { func TestUpdateAccountRequest_Validate(t *testing.T) {
type test struct { type test struct {
uar *UpdateAccountRequest uar *UpdateAccountRequest
err *acme.Error err *acme.Error

View file

@ -2,7 +2,9 @@ package api
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -259,10 +261,12 @@ func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) {
return return
} }
certBytes, err := cert.ToACME() var certBytes []byte
if err != nil { for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) {
api.WriteError(w, acme.WrapErrorISE(err, "error converting cert to ACME representation")) certBytes = append(certBytes, pem.EncodeToMemory(&pem.Block{
return Type: "CERTIFICATE",
Bytes: c.Raw,
})...)
} }
api.LogCertificate(w, cert.Leaf) api.LogCertificate(w, cert.Leaf)

View file

@ -3,6 +3,7 @@ package api
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
@ -47,7 +48,7 @@ func TestHandler_GetNonce(t *testing.T) {
} }
func TestHandler_GetDirectory(t *testing.T) { func TestHandler_GetDirectory(t *testing.T) {
linker := NewLinker("acme", "ca.smallstep.com") linker := NewLinker("ca.smallstep.com", "acme")
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
@ -306,7 +307,7 @@ func TestHandler_GetCertificate(t *testing.T) {
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/getCertificate-error": func(t *testing.T) test { "fail/db.GetCertificate-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -319,7 +320,7 @@ func TestHandler_GetCertificate(t *testing.T) {
err: acme.NewErrorISE("force"), err: acme.NewErrorISE("force"),
} }
}, },
"fail/decode-leaf-for-loggger": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -327,28 +328,12 @@ func TestHandler_GetCertificate(t *testing.T) {
db: &acme.MockDB{ db: &acme.MockDB{
MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) {
assert.Equals(t, id, certID) assert.Equals(t, id, certID)
return &acme.Certificate{}, nil return &acme.Certificate{AccountID: "foo"}, nil
}, },
}, },
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 401,
err: acme.NewErrorISE("failed to decode any certificates from generated certBytes"), err: acme.NewError(acme.ErrorUnauthorizedType, "account id mismatch"),
}
},
"fail/parse-x509-leaf-for-logger": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{
db: &acme.MockDB{
MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) {
assert.Equals(t, id, certID)
return &acme.Certificate{}, nil
},
},
ctx: ctx,
statusCode: 500,
err: acme.NewErrorISE("failed to parse generated leaf certificate"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -359,7 +344,13 @@ func TestHandler_GetCertificate(t *testing.T) {
db: &acme.MockDB{ db: &acme.MockDB{
MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) { MockGetCertificate: func(ctx context.Context, id string) (*acme.Certificate, error) {
assert.Equals(t, id, certID) assert.Equals(t, id, certID)
return &acme.Certificate{}, nil return &acme.Certificate{
AccountID: "accID",
OrderID: "ordID",
Leaf: leaf,
Intermediates: []*x509.Certificate{inter, root},
ID: id,
}, nil
}, },
}, },
ctx: ctx, ctx: ctx,

View file

@ -411,7 +411,7 @@ const (
func accountFromContext(ctx context.Context) (*acme.Account, error) { func accountFromContext(ctx context.Context) (*acme.Account, error) {
val, ok := ctx.Value(accContextKey).(*acme.Account) val, ok := ctx.Value(accContextKey).(*acme.Account)
if !ok || val == nil { if !ok || val == nil {
return nil, acme.NewErrorISE("account not in context") return nil, acme.NewError(acme.ErrorAccountDoesNotExistType, "account not in context")
} }
return val, nil return val, nil
} }

View file

@ -81,7 +81,7 @@ func Test_baseURLFromRequest(t *testing.T) {
} }
} }
func TestHandlerBaseURLFromRequest(t *testing.T) { func TestHandler_baseURLFromRequest(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest("GET", "/foo", nil) req := httptest.NewRequest("GET", "/foo", nil)
req.Host = "test.ca.smallstep.com:8080" req.Host = "test.ca.smallstep.com:8080"
@ -107,7 +107,7 @@ func TestHandlerBaseURLFromRequest(t *testing.T) {
h.baseURLFromRequest(next)(w, req) h.baseURLFromRequest(next)(w, req)
} }
func TestHandler_AddNonce(t *testing.T) { func TestHandler_addNonce(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-nonce" url := "https://ca.smallstep.com/acme/new-nonce"
type test struct { type test struct {
db acme.DB db acme.DB
@ -226,7 +226,7 @@ func TestHandler_addDirLink(t *testing.T) {
} }
} }
func TestHandler_VerifyContentType(t *testing.T) { func TestHandler_verifyContentType(t *testing.T) {
prov := newProv() prov := newProv()
provName := prov.GetName() provName := prov.GetName()
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
@ -340,7 +340,7 @@ func TestHandler_VerifyContentType(t *testing.T) {
} }
} }
func TestHandlerIsPostAsGet(t *testing.T) { func TestHandler_isPostAsGet(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-account" url := "https://ca.smallstep.com/acme/new-account"
type test struct { type test struct {
ctx context.Context ctx context.Context
@ -417,7 +417,7 @@ func (errReader) Close() error {
return nil return nil
} }
func TestHandlerParseJWS(t *testing.T) { func TestHandler_parseJWS(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-account" url := "https://ca.smallstep.com/acme/new-account"
type test struct { type test struct {
next nextHTTP next nextHTTP
@ -498,7 +498,7 @@ func TestHandlerParseJWS(t *testing.T) {
} }
} }
func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) { func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
_pub := jwk.Public() _pub := jwk.Public()
@ -558,7 +558,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_pub := _jwk.Public() _pub := _jwk.Public()
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, jwsContextKey, &_pub) ctx = context.WithValue(ctx, jwkContextKey, &_pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -570,7 +570,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
clone := &_pub clone := &_pub
clone.Algorithm = jose.HS256 clone.Algorithm = jose.HS256
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, jwsContextKey, clone) ctx = context.WithValue(ctx, jwkContextKey, clone)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
@ -579,7 +579,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, jwsContextKey, pub) ctx = context.WithValue(ctx, jwkContextKey, pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
@ -600,7 +600,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
clone := &_pub clone := &_pub
clone.Algorithm = "" clone.Algorithm = ""
ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS) ctx := context.WithValue(context.Background(), jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, jwsContextKey, pub) ctx = context.WithValue(ctx, jwkContextKey, pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
@ -624,7 +624,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
_parsed, err := jose.ParseJWS(_raw) _parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) ctx := context.WithValue(context.Background(), jwsContextKey, _parsed)
ctx = context.WithValue(ctx, jwsContextKey, pub) ctx = context.WithValue(ctx, jwkContextKey, pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
@ -648,7 +648,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
_parsed, err := jose.ParseJWS(_raw) _parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), jwsContextKey, _parsed) ctx := context.WithValue(context.Background(), jwsContextKey, _parsed)
ctx = context.WithValue(ctx, jwsContextKey, pub) ctx = context.WithValue(ctx, jwkContextKey, pub)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
@ -697,7 +697,7 @@ func TestHandlerVerifyAndExtractJWSPayload(t *testing.T) {
} }
} }
func TestHandlerLookupJWK(t *testing.T) { func TestHandler_lookupJWK(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
@ -899,7 +899,7 @@ func TestHandlerLookupJWK(t *testing.T) {
} }
} }
func TestHandlerExtractJWK(t *testing.T) { func TestHandler_extractJWK(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
@ -1095,7 +1095,7 @@ func TestHandlerExtractJWK(t *testing.T) {
} }
} }
func TestHandlerValidateJWS(t *testing.T) { func TestHandler_validateJWS(t *testing.T) {
url := "https://ca.smallstep.com/acme/account/1234" url := "https://ca.smallstep.com/acme/account/1234"
type test struct { type test struct {
db acme.DB db acme.DB

View file

@ -2,7 +2,6 @@ package acme
import ( import (
"crypto/x509" "crypto/x509"
"encoding/pem"
) )
// Certificate options with which to create and store a cert object. // Certificate options with which to create and store a cert object.
@ -13,15 +12,3 @@ type Certificate struct {
Leaf *x509.Certificate Leaf *x509.Certificate
Intermediates []*x509.Certificate Intermediates []*x509.Certificate
} }
// ToACME encodes the entire X509 chain into a PEM list.
func (cert *Certificate) ToACME() ([]byte, error) {
var ret []byte
for _, c := range append([]*x509.Certificate{cert.Leaf}, cert.Intermediates...) {
ret = append(ret, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: c.Raw,
})...)
}
return ret, nil
}

View file

@ -271,6 +271,10 @@ type Error struct {
// NewError creates a new Error type. // NewError creates a new Error type.
func NewError(pt ProblemType, msg string, args ...interface{}) *Error { func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
return newError(pt, errors.Errorf(msg, args...))
}
func newError(pt ProblemType, err error) *Error {
meta, ok := errorMap[pt] meta, ok := errorMap[pt]
if !ok { if !ok {
meta = errorServerInternalMetadata meta = errorServerInternalMetadata
@ -278,7 +282,7 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
Type: meta.typ, Type: meta.typ,
Detail: meta.details, Detail: meta.details,
Status: meta.status, Status: meta.status,
Err: errors.Errorf("unrecognized problemType %v", pt), Err: err,
} }
} }
@ -286,7 +290,7 @@ func NewError(pt ProblemType, msg string, args ...interface{}) *Error {
Type: meta.typ, Type: meta.typ,
Detail: meta.details, Detail: meta.details,
Status: meta.status, Status: meta.status,
Err: errors.Errorf(msg, args...), Err: err,
} }
} }
@ -308,7 +312,7 @@ func WrapError(typ ProblemType, err error, msg string, args ...interface{}) *Err
} }
return e return e
default: default:
return NewError(ErrorServerInternalType, msg, args...) return newError(typ, errors.Wrapf(err, msg, args...))
} }
} }