forked from TrueCloudLab/certificates
Store and verify Acme account location (#1386)
* Store and verify account location on acme requests Co-authored-by: Herman Slatman <hslatman@users.noreply.github.com> Co-authored-by: Mariano Cano <mariano@smallstep.com>
This commit is contained in:
parent
cbbc54e980
commit
7731edd816
8 changed files with 237 additions and 91 deletions
|
@ -20,6 +20,16 @@ type Account struct {
|
||||||
Status Status `json:"status"`
|
Status Status `json:"status"`
|
||||||
OrdersURL string `json:"orders"`
|
OrdersURL string `json:"orders"`
|
||||||
ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"`
|
ExternalAccountBinding interface{} `json:"externalAccountBinding,omitempty"`
|
||||||
|
LocationPrefix string `json:"-"`
|
||||||
|
ProvisionerName string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLocation returns the URL location of the given account.
|
||||||
|
func (a *Account) GetLocation() string {
|
||||||
|
if a.LocationPrefix == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.LocationPrefix + a.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToLog enables response logging.
|
// ToLog enables response logging.
|
||||||
|
@ -72,6 +82,7 @@ func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions {
|
||||||
IPRanges: p.X509.Allowed.IPRanges,
|
IPRanges: p.X509.Allowed.IPRanges,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
|
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
|
||||||
if p == nil {
|
if p == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -66,6 +66,23 @@ func TestKeyToID(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_GetLocation(t *testing.T) {
|
||||||
|
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||||
|
type test struct {
|
||||||
|
acc *Account
|
||||||
|
exp string
|
||||||
|
}
|
||||||
|
tests := map[string]test{
|
||||||
|
"empty": {acc: &Account{LocationPrefix: ""}, exp: ""},
|
||||||
|
"not-empty": {acc: &Account{ID: "bar", LocationPrefix: locationPrefix}, exp: locationPrefix + "bar"},
|
||||||
|
}
|
||||||
|
for name, tc := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert.Equals(t, tc.acc.GetLocation(), tc.exp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_IsValid(t *testing.T) {
|
func TestAccount_IsValid(t *testing.T) {
|
||||||
type test struct {
|
type test struct {
|
||||||
acc *Account
|
acc *Account
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -67,6 +68,12 @@ func (u *UpdateAccountRequest) Validate() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getAccountLocationPath returns the current account URL location.
|
||||||
|
// Returned location will be of the form: https://<ca-url>/acme/<provisioner>/account/<accID>
|
||||||
|
func getAccountLocationPath(ctx context.Context, linker acme.Linker, accID string) string {
|
||||||
|
return linker.GetLink(ctx, acme.AccountLinkType, accID)
|
||||||
|
}
|
||||||
|
|
||||||
// NewAccount is the handler resource for creating new ACME accounts.
|
// NewAccount is the handler resource for creating new ACME accounts.
|
||||||
func NewAccount(w http.ResponseWriter, r *http.Request) {
|
func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
@ -125,9 +132,11 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
acc = &acme.Account{
|
acc = &acme.Account{
|
||||||
Key: jwk,
|
Key: jwk,
|
||||||
Contact: nar.Contact,
|
Contact: nar.Contact,
|
||||||
Status: acme.StatusValid,
|
Status: acme.StatusValid,
|
||||||
|
LocationPrefix: getAccountLocationPath(ctx, linker, ""),
|
||||||
|
ProvisionerName: prov.GetName(),
|
||||||
}
|
}
|
||||||
if err := db.CreateAccount(ctx, acc); err != nil {
|
if err := db.CreateAccount(ctx, acc); err != nil {
|
||||||
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
render.Error(w, acme.WrapErrorISE(err, "error creating account"))
|
||||||
|
@ -152,7 +161,7 @@ func NewAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
linker.LinkAccount(ctx, acc)
|
linker.LinkAccount(ctx, acc)
|
||||||
|
|
||||||
w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
|
w.Header().Set("Location", getAccountLocationPath(ctx, linker, acc.ID))
|
||||||
render.JSONStatus(w, acc, httpStatus)
|
render.JSONStatus(w, acc, httpStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
|
@ -16,7 +17,6 @@ import (
|
||||||
"github.com/smallstep/certificates/api/render"
|
"github.com/smallstep/certificates/api/render"
|
||||||
"github.com/smallstep/certificates/authority/provisioner"
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
"github.com/smallstep/certificates/logging"
|
"github.com/smallstep/certificates/logging"
|
||||||
"github.com/smallstep/nosql"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||||
|
@ -293,7 +293,6 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
db := acme.MustDatabaseFromContext(ctx)
|
db := acme.MustDatabaseFromContext(ctx)
|
||||||
linker := acme.MustLinkerFromContext(ctx)
|
|
||||||
|
|
||||||
jws, err := jwsFromContext(ctx)
|
jws, err := jwsFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -301,19 +300,16 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
|
||||||
kid := jws.Signatures[0].Protected.KeyID
|
kid := jws.Signatures[0].Protected.KeyID
|
||||||
if !strings.HasPrefix(kid, kidPrefix) {
|
if kid == "" {
|
||||||
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
render.Error(w, acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"))
|
||||||
"kid does not have required prefix; expected %s, but got %s",
|
|
||||||
kidPrefix, kid))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
accID := strings.TrimPrefix(kid, kidPrefix)
|
accID := path.Base(kid)
|
||||||
acc, err := db.GetAccount(ctx, accID)
|
acc, err := db.GetAccount(ctx, accID)
|
||||||
switch {
|
switch {
|
||||||
case nosql.IsErrNotFound(err):
|
case acme.IsErrNotFound(err):
|
||||||
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
|
||||||
return
|
return
|
||||||
case err != nil:
|
case err != nil:
|
||||||
|
@ -324,6 +320,45 @@ func lookupJWK(next nextHTTP) nextHTTP {
|
||||||
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account is not active"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if storedLocation := acc.GetLocation(); storedLocation != "" {
|
||||||
|
if kid != storedLocation {
|
||||||
|
// ACME accounts should have a stored location equivalent to the
|
||||||
|
// kid in the ACME request.
|
||||||
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"kid does not match stored account location; expected %s, but got %s",
|
||||||
|
storedLocation, kid))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the provisioner with which the account was created
|
||||||
|
// matches the provisioner in the request URL.
|
||||||
|
reqProv := acme.MustProvisionerFromContext(ctx)
|
||||||
|
reqProvName := reqProv.GetName()
|
||||||
|
accProvName := acc.ProvisionerName
|
||||||
|
if reqProvName != accProvName {
|
||||||
|
// Provisioner in the URL must match the provisioner with
|
||||||
|
// which the account was created.
|
||||||
|
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"account provisioner does not match requested provisioner; account provisioner = %s, requested provisioner = %s",
|
||||||
|
accProvName, reqProvName))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// This code will only execute for old ACME accounts that do
|
||||||
|
// not have a cached location. The following validation was
|
||||||
|
// the original implementation of the `kid` check which has
|
||||||
|
// since been deprecated. However, the code will remain to
|
||||||
|
// ensure consistent behavior for old ACME accounts.
|
||||||
|
linker := acme.MustLinkerFromContext(ctx)
|
||||||
|
kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
|
||||||
|
if !strings.HasPrefix(kid, kidPrefix) {
|
||||||
|
render.Error(w, acme.NewError(acme.ErrorMalformedType,
|
||||||
|
"kid does not have required prefix; expected %s, but got %s",
|
||||||
|
kidPrefix, kid))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
ctx = context.WithValue(ctx, accContextKey, acc)
|
ctx = context.WithValue(ctx, accContextKey, acc)
|
||||||
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
ctx = context.WithValue(ctx, jwkContextKey, acc.Key)
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
|
|
|
@ -17,7 +17,6 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/smallstep/assert"
|
"github.com/smallstep/assert"
|
||||||
"github.com/smallstep/certificates/acme"
|
"github.com/smallstep/certificates/acme"
|
||||||
"github.com/smallstep/nosql/database"
|
|
||||||
"go.step.sm/crypto/jose"
|
"go.step.sm/crypto/jose"
|
||||||
"go.step.sm/crypto/keyutil"
|
"go.step.sm/crypto/keyutil"
|
||||||
)
|
)
|
||||||
|
@ -678,31 +677,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
|
err: acme.NewError(acme.ErrorMalformedType, "signature missing 'kid'"),
|
||||||
}
|
|
||||||
},
|
|
||||||
"fail/bad-kid-prefix": func(t *testing.T) test {
|
|
||||||
_so := new(jose.SignerOptions)
|
|
||||||
_so.WithHeader("kid", "foo")
|
|
||||||
_signer, err := jose.NewSigner(jose.SigningKey{
|
|
||||||
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
|
|
||||||
Key: jwk.Key,
|
|
||||||
}, _so)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_jws, err := _signer.Sign([]byte("baz"))
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_raw, err := _jws.CompactSerialize()
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
_parsed, err := jose.ParseJWS(_raw)
|
|
||||||
assert.FatalError(t, err)
|
|
||||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, _parsed)
|
|
||||||
return test{
|
|
||||||
db: &acme.MockDB{},
|
|
||||||
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
|
||||||
ctx: ctx,
|
|
||||||
statusCode: 400,
|
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/account-not-found": func(t *testing.T) test {
|
"fail/account-not-found": func(t *testing.T) test {
|
||||||
|
@ -713,7 +688,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
db: &acme.MockDB{
|
db: &acme.MockDB{
|
||||||
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
|
||||||
assert.Equals(t, accID, accID)
|
assert.Equals(t, accID, accID)
|
||||||
return nil, database.ErrNotFound
|
return nil, acme.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -754,7 +729,77 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"),
|
err: acme.NewError(acme.ErrorUnauthorizedType, "account is not active"),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"fail/account-with-location-prefix/bad-kid": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{LocationPrefix: "foobar", Status: "valid"}
|
||||||
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
|
return test{
|
||||||
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, id, accID)
|
||||||
|
return acc, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
statusCode: http.StatusUnauthorized,
|
||||||
|
err: acme.NewError(acme.ErrorUnauthorizedType, "kid does not match stored account location; expected foobar, but %q", prefix+accID),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"fail/account-with-location-prefix/bad-provisioner": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: "other"}
|
||||||
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
|
return test{
|
||||||
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, id, accID)
|
||||||
|
return acc, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_acc, err := accountFromContext(r.Context())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, _acc, acc)
|
||||||
|
_jwk, err := jwkFromContext(r.Context())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, _jwk, jwk)
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
statusCode: http.StatusUnauthorized,
|
||||||
|
err: acme.NewError(acme.ErrorUnauthorizedType,
|
||||||
|
"account provisioner does not match requested provisioner; account provisioner = %s, reqested provisioner = %s",
|
||||||
|
prov.GetName(), "other"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/account-with-location-prefix": func(t *testing.T) test {
|
||||||
|
acc := &acme.Account{LocationPrefix: prefix + accID, Status: "valid", Key: jwk, ProvisionerName: prov.GetName()}
|
||||||
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
|
return test{
|
||||||
|
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
|
||||||
|
db: &acme.MockDB{
|
||||||
|
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
|
||||||
|
assert.Equals(t, id, accID)
|
||||||
|
return acc, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
next: func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_acc, err := accountFromContext(r.Context())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, _acc, acc)
|
||||||
|
_jwk, err := jwkFromContext(r.Context())
|
||||||
|
assert.FatalError(t, err)
|
||||||
|
assert.Equals(t, _jwk, jwk)
|
||||||
|
w.Write(testBody)
|
||||||
|
},
|
||||||
|
statusCode: http.StatusOK,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ok/account-without-location-prefix": func(t *testing.T) test {
|
||||||
acc := &acme.Account{Status: "valid", Key: jwk}
|
acc := &acme.Account{Status: "valid", Key: jwk}
|
||||||
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
ctx := acme.NewProvisionerContext(context.Background(), prov)
|
||||||
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
|
||||||
|
|
|
@ -12,6 +12,12 @@ import (
|
||||||
// account.
|
// account.
|
||||||
var ErrNotFound = errors.New("not found")
|
var ErrNotFound = errors.New("not found")
|
||||||
|
|
||||||
|
// IsErrNotFound returns true if the error is a "not found" error. Returns false
|
||||||
|
// otherwise.
|
||||||
|
func IsErrNotFound(err error) bool {
|
||||||
|
return errors.Is(err, ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
// DB is the DB interface expected by the step-ca ACME API.
|
// DB is the DB interface expected by the step-ca ACME API.
|
||||||
type DB interface {
|
type DB interface {
|
||||||
CreateAccount(ctx context.Context, acc *Account) error
|
CreateAccount(ctx context.Context, acc *Account) error
|
||||||
|
|
|
@ -13,12 +13,14 @@ import (
|
||||||
|
|
||||||
// dbAccount represents an ACME account.
|
// dbAccount represents an ACME account.
|
||||||
type dbAccount struct {
|
type dbAccount struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Key *jose.JSONWebKey `json:"key"`
|
Key *jose.JSONWebKey `json:"key"`
|
||||||
Contact []string `json:"contact,omitempty"`
|
Contact []string `json:"contact,omitempty"`
|
||||||
Status acme.Status `json:"status"`
|
Status acme.Status `json:"status"`
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
LocationPrefix string `json:"locationPrefix"`
|
||||||
DeactivatedAt time.Time `json:"deactivatedAt"`
|
ProvisionerName string `json:"provisionerName"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
DeactivatedAt time.Time `json:"deactivatedAt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dba *dbAccount) clone() *dbAccount {
|
func (dba *dbAccount) clone() *dbAccount {
|
||||||
|
@ -62,10 +64,12 @@ func (db *DB) GetAccount(ctx context.Context, id string) (*acme.Account, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &acme.Account{
|
return &acme.Account{
|
||||||
Status: dbacc.Status,
|
Status: dbacc.Status,
|
||||||
Contact: dbacc.Contact,
|
Contact: dbacc.Contact,
|
||||||
Key: dbacc.Key,
|
Key: dbacc.Key,
|
||||||
ID: dbacc.ID,
|
ID: dbacc.ID,
|
||||||
|
LocationPrefix: dbacc.LocationPrefix,
|
||||||
|
ProvisionerName: dbacc.ProvisionerName,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,11 +91,13 @@ func (db *DB) CreateAccount(ctx context.Context, acc *acme.Account) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
dba := &dbAccount{
|
dba := &dbAccount{
|
||||||
ID: acc.ID,
|
ID: acc.ID,
|
||||||
Key: acc.Key,
|
Key: acc.Key,
|
||||||
Contact: acc.Contact,
|
Contact: acc.Contact,
|
||||||
Status: acc.Status,
|
Status: acc.Status,
|
||||||
CreatedAt: clock.Now(),
|
CreatedAt: clock.Now(),
|
||||||
|
LocationPrefix: acc.LocationPrefix,
|
||||||
|
ProvisionerName: acc.ProvisionerName,
|
||||||
}
|
}
|
||||||
|
|
||||||
kid, err := acme.KeyToID(dba.Key)
|
kid, err := acme.KeyToID(dba.Key)
|
||||||
|
|
|
@ -197,6 +197,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||||
|
|
||||||
func TestDB_GetAccount(t *testing.T) {
|
func TestDB_GetAccount(t *testing.T) {
|
||||||
accID := "accID"
|
accID := "accID"
|
||||||
|
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||||
|
provisionerName := "foo"
|
||||||
type test struct {
|
type test struct {
|
||||||
db nosql.DB
|
db nosql.DB
|
||||||
err error
|
err error
|
||||||
|
@ -222,12 +224,14 @@ func TestDB_GetAccount(t *testing.T) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
dbacc := &dbAccount{
|
dbacc := &dbAccount{
|
||||||
ID: accID,
|
ID: accID,
|
||||||
Status: acme.StatusDeactivated,
|
Status: acme.StatusDeactivated,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
DeactivatedAt: now,
|
DeactivatedAt: now,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
Key: jwk,
|
Key: jwk,
|
||||||
|
LocationPrefix: locationPrefix,
|
||||||
|
ProvisionerName: provisionerName,
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(dbacc)
|
b, err := json.Marshal(dbacc)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
@ -266,6 +270,8 @@ func TestDB_GetAccount(t *testing.T) {
|
||||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
|
assert.Equals(t, acc.LocationPrefix, tc.dbacc.LocationPrefix)
|
||||||
|
assert.Equals(t, acc.ProvisionerName, tc.dbacc.ProvisionerName)
|
||||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -379,6 +385,7 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDB_CreateAccount(t *testing.T) {
|
func TestDB_CreateAccount(t *testing.T) {
|
||||||
|
locationPrefix := "https://test.ca.smallstep.com/acme/foo/account/"
|
||||||
type test struct {
|
type test struct {
|
||||||
db nosql.DB
|
db nosql.DB
|
||||||
acc *acme.Account
|
acc *acme.Account
|
||||||
|
@ -390,9 +397,10 @@ func TestDB_CreateAccount(t *testing.T) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
acc := &acme.Account{
|
acc := &acme.Account{
|
||||||
Status: acme.StatusValid,
|
Status: acme.StatusValid,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
Key: jwk,
|
Key: jwk,
|
||||||
|
LocationPrefix: locationPrefix,
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
db: &db.MockNoSQLDB{
|
db: &db.MockNoSQLDB{
|
||||||
|
@ -413,9 +421,10 @@ func TestDB_CreateAccount(t *testing.T) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
acc := &acme.Account{
|
acc := &acme.Account{
|
||||||
Status: acme.StatusValid,
|
Status: acme.StatusValid,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
Key: jwk,
|
Key: jwk,
|
||||||
|
LocationPrefix: locationPrefix,
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
db: &db.MockNoSQLDB{
|
db: &db.MockNoSQLDB{
|
||||||
|
@ -436,9 +445,10 @@ func TestDB_CreateAccount(t *testing.T) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
acc := &acme.Account{
|
acc := &acme.Account{
|
||||||
Status: acme.StatusValid,
|
Status: acme.StatusValid,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
Key: jwk,
|
Key: jwk,
|
||||||
|
LocationPrefix: locationPrefix,
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
db: &db.MockNoSQLDB{
|
db: &db.MockNoSQLDB{
|
||||||
|
@ -456,6 +466,8 @@ func TestDB_CreateAccount(t *testing.T) {
|
||||||
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||||
assert.Equals(t, dbacc.ID, string(key))
|
assert.Equals(t, dbacc.ID, string(key))
|
||||||
assert.Equals(t, dbacc.Contact, acc.Contact)
|
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||||
|
assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix)
|
||||||
|
assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName)
|
||||||
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||||
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||||
|
@ -479,9 +491,10 @@ func TestDB_CreateAccount(t *testing.T) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
acc := &acme.Account{
|
acc := &acme.Account{
|
||||||
Status: acme.StatusValid,
|
Status: acme.StatusValid,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
Key: jwk,
|
Key: jwk,
|
||||||
|
LocationPrefix: locationPrefix,
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
db: &db.MockNoSQLDB{
|
db: &db.MockNoSQLDB{
|
||||||
|
@ -500,6 +513,8 @@ func TestDB_CreateAccount(t *testing.T) {
|
||||||
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
assert.FatalError(t, json.Unmarshal(nu, dbacc))
|
||||||
assert.Equals(t, dbacc.ID, string(key))
|
assert.Equals(t, dbacc.ID, string(key))
|
||||||
assert.Equals(t, dbacc.Contact, acc.Contact)
|
assert.Equals(t, dbacc.Contact, acc.Contact)
|
||||||
|
assert.Equals(t, dbacc.LocationPrefix, acc.LocationPrefix)
|
||||||
|
assert.Equals(t, dbacc.ProvisionerName, acc.ProvisionerName)
|
||||||
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
assert.Equals(t, dbacc.Key.KeyID, acc.Key.KeyID)
|
||||||
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbacc.CreatedAt))
|
||||||
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
assert.True(t, clock.Now().Add(time.Minute).After(dbacc.CreatedAt))
|
||||||
|
@ -539,12 +554,14 @@ func TestDB_UpdateAccount(t *testing.T) {
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
dbacc := &dbAccount{
|
dbacc := &dbAccount{
|
||||||
ID: accID,
|
ID: accID,
|
||||||
Status: acme.StatusDeactivated,
|
Status: acme.StatusDeactivated,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
DeactivatedAt: now,
|
DeactivatedAt: now,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"foo", "bar"},
|
||||||
Key: jwk,
|
LocationPrefix: "foo",
|
||||||
|
ProvisionerName: "alpha",
|
||||||
|
Key: jwk,
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(dbacc)
|
b, err := json.Marshal(dbacc)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
@ -644,10 +661,12 @@ func TestDB_UpdateAccount(t *testing.T) {
|
||||||
},
|
},
|
||||||
"ok": func(t *testing.T) test {
|
"ok": func(t *testing.T) test {
|
||||||
acc := &acme.Account{
|
acc := &acme.Account{
|
||||||
ID: accID,
|
ID: accID,
|
||||||
Status: acme.StatusDeactivated,
|
Status: acme.StatusDeactivated,
|
||||||
Contact: []string{"foo", "bar"},
|
Contact: []string{"baz", "zap"},
|
||||||
Key: jwk,
|
LocationPrefix: "bar",
|
||||||
|
ProvisionerName: "beta",
|
||||||
|
Key: jwk,
|
||||||
}
|
}
|
||||||
return test{
|
return test{
|
||||||
acc: acc,
|
acc: acc,
|
||||||
|
@ -666,7 +685,10 @@ func TestDB_UpdateAccount(t *testing.T) {
|
||||||
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
||||||
assert.Equals(t, dbNew.ID, dbacc.ID)
|
assert.Equals(t, dbNew.ID, dbacc.ID)
|
||||||
assert.Equals(t, dbNew.Status, acc.Status)
|
assert.Equals(t, dbNew.Status, acc.Status)
|
||||||
assert.Equals(t, dbNew.Contact, dbacc.Contact)
|
assert.Equals(t, dbNew.Contact, acc.Contact)
|
||||||
|
// LocationPrefix should not change.
|
||||||
|
assert.Equals(t, dbNew.LocationPrefix, dbacc.LocationPrefix)
|
||||||
|
assert.Equals(t, dbNew.ProvisionerName, dbacc.ProvisionerName)
|
||||||
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
assert.Equals(t, dbNew.Key.KeyID, dbacc.Key.KeyID)
|
||||||
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
assert.Equals(t, dbNew.CreatedAt, dbacc.CreatedAt)
|
||||||
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
assert.True(t, dbNew.DeactivatedAt.Add(-time.Minute).Before(now))
|
||||||
|
@ -686,12 +708,7 @@ func TestDB_UpdateAccount(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if assert.Nil(t, tc.err) {
|
assert.Nil(t, tc.err)
|
||||||
assert.Equals(t, tc.acc.ID, dbacc.ID)
|
|
||||||
assert.Equals(t, tc.acc.Status, dbacc.Status)
|
|
||||||
assert.Equals(t, tc.acc.Contact, dbacc.Contact)
|
|
||||||
assert.Equals(t, tc.acc.Key.KeyID, dbacc.Key.KeyID)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue