add authorization and order unit tests

This commit is contained in:
max furman 2021-03-24 16:50:35 -07:00
parent a58466589f
commit c0a9f24798
6 changed files with 473 additions and 1570 deletions

View file

@ -1,764 +1,81 @@
package acme package acme
import ( import (
"fmt" "crypto"
"time" "encoding/base64"
"testing"
"github.com/smallstep/certificates/authority/provisioner" "github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose"
) )
var ( func TestKeyToID(t *testing.T) {
defaultDisableRenewal = false type test struct {
globalProvisionerClaims = provisioner.Claims{ jwk *jose.JSONWebKey
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, exp string
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, err *Error
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
} }
) tests := map[string]func(t *testing.T) test{
"fail/error-generating-thumbprint": func(t *testing.T) test {
func newProv() Provisioner {
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-provisioner.com",
}
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
fmt.Printf("%v", err)
}
return p
}
/*
func newAcc() (*Account, error) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
if err != nil { assert.FatalError(t, err)
return nil, err jwk.Key = "foo"
return test{
jwk: jwk,
err: NewErrorISE("error generating jwk thumbprint: square/go-jose: unknown key type 'string'"),
}
},
"ok": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
kid, err := jwk.Thumbprint(crypto.SHA256)
assert.FatalError(t, err)
return test{
jwk: jwk,
exp: base64.RawURLEncoding.EncodeToString(kid),
} }
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, true, nil
}, },
} }
return newAccount(mockdb, AccountOptions{ for name, run := range tests {
Key: jwk, Contact: []string{"foo", "bar"}, t.Run(name, func(t *testing.T) {
tc := run(t)
if id, err := KeyToID(tc.jwk); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *Error:
assert.Equals(t, k.Type, tc.err.Type)
assert.Equals(t, k.Detail, tc.err.Detail)
assert.Equals(t, k.Status, tc.err.Status)
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
assert.Equals(t, k.Detail, tc.err.Detail)
default:
assert.FatalError(t, errors.New("unexpected error type"))
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, id, tc.exp)
}
}
}) })
} }
*/ }
/* func TestAccount_IsValid(t *testing.T) {
func TestGetAccountByID(t *testing.T) {
type test struct { type test struct {
id string
db nosql.DB
acc *Account acc *Account
err *Error exp bool
} }
tests := map[string]func(t *testing.T) test{ tests := map[string]test{
"fail/not-found": func(t *testing.T) test { "valid": {acc: &Account{Status: StatusValid}, exp: true},
acc, err := newAcc() "invalid": {acc: &Account{Status: StatusInvalid}, exp: false},
assert.FatalError(t, err)
return test{
acc: acc,
id: acc.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
},
err: NewError(ErrorMalformedType, "account %s not found: not found", acc.ID),
} }
}, for name, tc := range tests {
"fail/db-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
id: acc.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: NewErrorISE("error loading account %s: force", acc.ID),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
id: acc.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
return nil, nil
},
},
err: ServerInternalErr(errors.New("error unmarshaling account: unexpected end of JSON input")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
id: acc.ID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
return b, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) assert.Equals(t, tc.acc.IsValid(), tc.exp)
if acc, err := getAccountByID(tc.db, tc.id); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.acc.ID, acc.ID)
assert.Equals(t, tc.acc.Status, acc.Status)
assert.Equals(t, tc.acc.Created, acc.Created)
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
assert.Equals(t, tc.acc.Contact, acc.Contact)
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
}
}
}) })
} }
} }
func TestGetAccountByKeyID(t *testing.T) {
type test struct {
kid string
db nosql.DB
acc *account
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/kid-not-found": func(t *testing.T) test {
return test{
kid: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
},
err: MalformedErr(errors.Errorf("account with key id foo not found: not found")),
}
},
"fail/db-error": func(t *testing.T) test {
return test{
kid: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading key-account index: force")),
}
},
"fail/getAccount-error": func(t *testing.T) test {
count := 0
return test{
kid: "foo",
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 0 {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte("foo"))
count++
return []byte("bar"), nil
}
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading account bar: force")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
count := 0
return test{
kid: acc.Key.KeyID,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
var ret []byte
switch count {
case 0:
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(acc.Key.KeyID))
ret = []byte(acc.ID)
case 1:
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
ret = b
}
count++
return ret, nil
},
},
acc: acc,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if acc, err := getAccountByKeyID(tc.db, tc.kid); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.acc.ID, acc.ID)
assert.Equals(t, tc.acc.Status, acc.Status)
assert.Equals(t, tc.acc.Created, acc.Created)
assert.Equals(t, tc.acc.Deactivated, acc.Deactivated)
assert.Equals(t, tc.acc.Contact, acc.Contact)
assert.Equals(t, tc.acc.Key.KeyID, acc.Key.KeyID)
}
}
})
}
}
func TestAccountToACME(t *testing.T) {
dir := newDirectory("ca.smallstep.com", "acme")
prov := newProv()
provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)
type test struct {
acc *account
err *Error
}
tests := map[string]func(t *testing.T) test{
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{acc: acc}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
acmeAccount, err := tc.acc.toACME(ctx, nil, dir)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, acmeAccount.ID, tc.acc.ID)
assert.Equals(t, acmeAccount.Status, tc.acc.Status)
assert.Equals(t, acmeAccount.Contact, tc.acc.Contact)
assert.Equals(t, acmeAccount.Key.KeyID, tc.acc.Key.KeyID)
assert.Equals(t, acmeAccount.Orders,
fmt.Sprintf("%s/acme/%s/account/%s/orders", baseURL.String(), provName, tc.acc.ID))
}
}
})
}
}
func TestAccountSave(t *testing.T) {
type test struct {
acc, old *account
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/old-nil/swap-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing account: force")),
}
},
"fail/old-nil/swap-false": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
return test{
acc: acc,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("foo"), false, nil
},
},
err: ServerInternalErr(errors.New("error storing account; value has changed since last read")),
}
},
"ok/old-nil": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, old, nil)
assert.Equals(t, b, newval)
assert.Equals(t, bucket, accountTable)
assert.Equals(t, []byte(acc.ID), key)
return nil, true, nil
},
},
}
},
"ok/old-not-nil": func(t *testing.T) test {
oldAcc, err := newAcc()
assert.FatalError(t, err)
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(oldAcc)
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
old: oldAcc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, old, oldb)
assert.Equals(t, newval, b)
assert.Equals(t, bucket, accountTable)
assert.Equals(t, []byte(acc.ID), key)
return []byte("foo"), true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := tc.acc.save(tc.db, tc.old); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestAccountSaveNew(t *testing.T) {
type test struct {
acc *account
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/keyToID-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
acc.Key.Key = "foo"
return test{
acc: acc,
err: ServerInternalErr(errors.New("error generating jwk thumbprint: square/go-jose: unknown key type 'string'")),
}
},
"fail/swap-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
kid, err := keyToID(acc.Key)
assert.FatalError(t, err)
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
assert.Equals(t, old, nil)
assert.Equals(t, newval, []byte(acc.ID))
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
}
},
"fail/swap-false": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
kid, err := keyToID(acc.Key)
assert.FatalError(t, err)
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
assert.Equals(t, old, nil)
assert.Equals(t, newval, []byte(acc.ID))
return nil, false, nil
},
},
err: ServerInternalErr(errors.New("key-id to account-id index already exists")),
}
},
"fail/save-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
kid, err := keyToID(acc.Key)
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
count := 0
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
assert.Equals(t, old, nil)
assert.Equals(t, newval, []byte(acc.ID))
count++
return nil, true, nil
}
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, nil)
assert.Equals(t, newval, b)
return nil, false, errors.New("force")
},
MDel: func(bucket, key []byte) error {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
return nil
},
},
err: ServerInternalErr(errors.New("error storing account: force")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
kid, err := keyToID(acc.Key)
assert.FatalError(t, err)
b, err := json.Marshal(acc)
assert.FatalError(t, err)
count := 0
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
assert.Equals(t, old, nil)
assert.Equals(t, newval, []byte(acc.ID))
count++
return nil, true, nil
}
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, nil)
assert.Equals(t, newval, b)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := tc.acc.saveNew(tc.db); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestAccountUpdate(t *testing.T) {
type test struct {
acc *account
contact []string
db nosql.DB
res []byte
err *Error
}
contact := []string{"foo", "bar"}
tests := map[string]func(t *testing.T) test{
"fail/save-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(acc)
assert.FatalError(t, err)
_acc := *acc
clone := &_acc
clone.Contact = contact
b, err := json.Marshal(clone)
assert.FatalError(t, err)
return test{
acc: acc,
contact: contact,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, oldb)
assert.Equals(t, newval, b)
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing account: force")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(acc)
assert.FatalError(t, err)
_acc := *acc
clone := &_acc
clone.Contact = contact
b, err := json.Marshal(clone)
assert.FatalError(t, err)
return test{
acc: acc,
contact: contact,
res: b,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, oldb)
assert.Equals(t, newval, b)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
acc, err := tc.acc.update(tc.db, tc.contact)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
b, err := json.Marshal(acc)
assert.FatalError(t, err)
assert.Equals(t, b, tc.res)
}
}
})
}
}
func TestAccountDeactivate(t *testing.T) {
type test struct {
acc *account
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/save-error": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, oldb)
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing account: force")),
}
},
"ok": func(t *testing.T) test {
acc, err := newAcc()
assert.FatalError(t, err)
oldb, err := json.Marshal(acc)
assert.FatalError(t, err)
return test{
acc: acc,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, accountTable)
assert.Equals(t, key, []byte(acc.ID))
assert.Equals(t, old, oldb)
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
acc, err := tc.acc.deactivate(tc.db)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, acc.ID, tc.acc.ID)
assert.Equals(t, acc.Contact, tc.acc.Contact)
assert.Equals(t, acc.Status, StatusDeactivated)
assert.Equals(t, acc.Key.KeyID, tc.acc.Key.KeyID)
assert.Equals(t, acc.Created, tc.acc.Created)
assert.True(t, acc.Deactivated.Before(time.Now().Add(time.Minute)))
assert.True(t, acc.Deactivated.After(time.Now().Add(-time.Minute)))
}
}
})
}
}
func TestNewAccount(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
kid, err := keyToID(jwk)
assert.FatalError(t, err)
ops := AccountOptions{
Key: jwk,
Contact: []string{"foo", "bar"},
}
type test struct {
ops AccountOptions
db nosql.DB
err *Error
id *string
}
tests := map[string]func(t *testing.T) test{
"fail/store-error": func(t *testing.T) test {
return test{
ops: ops,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error setting key-id to account-id index: force")),
}
},
"ok": func(t *testing.T) test {
var _id string
id := &_id
count := 0
return test{
ops: ops,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
switch count {
case 0:
assert.Equals(t, bucket, accountByKeyIDTable)
assert.Equals(t, key, []byte(kid))
case 1:
assert.Equals(t, bucket, accountTable)
*id = string(key)
}
count++
return nil, true, nil
},
},
id: id,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
acc, err := newAccount(tc.db, tc.ops)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, acc.ID, *tc.id)
assert.Equals(t, acc.Status, StatusValid)
assert.Equals(t, acc.Contact, ops.Contact)
assert.Equals(t, acc.Key.KeyID, ops.Key.KeyID)
assert.True(t, acc.Deactivated.IsZero())
assert.True(t, acc.Created.Before(time.Now().UTC().Add(time.Minute)))
assert.True(t, acc.Created.After(time.Now().UTC().Add(-1*time.Minute)))
}
}
})
}
}
*/

View file

@ -57,6 +57,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
return nil return nil
} }
az.Status = StatusValid az.Status = StatusValid
az.Error = nil
default: default:
return NewErrorISE("unrecognized authorization status: %s", az.Status) return NewErrorISE("unrecognized authorization status: %s", az.Status)
} }

150
acme/authorization_test.go Normal file
View file

@ -0,0 +1,150 @@
package acme
import (
"context"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
)
func TestAuthorization_UpdateStatus(t *testing.T) {
type test struct {
az *Authorization
err *Error
db DB
}
tests := map[string]func(t *testing.T) test{
"ok/already-invalid": func(t *testing.T) test {
az := &Authorization{
Status: StatusInvalid,
}
return test{
az: az,
}
},
"ok/already-valid": func(t *testing.T) test {
az := &Authorization{
Status: StatusInvalid,
}
return test{
az: az,
}
},
"fail/error-unexpected-status": func(t *testing.T) test {
az := &Authorization{
Status: "foo",
}
return test{
az: az,
err: NewErrorISE("unrecognized authorization status: %s", az.Status),
}
},
"ok/expired": func(t *testing.T) test {
now := clock.Now()
az := &Authorization{
ID: "azID",
AccountID: "accID",
Status: StatusPending,
ExpiresAt: now.Add(-5 * time.Minute),
}
return test{
az: az,
db: &MockDB{
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
assert.Equals(t, updaz.ID, az.ID)
assert.Equals(t, updaz.AccountID, az.AccountID)
assert.Equals(t, updaz.Status, StatusInvalid)
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
return nil
},
},
}
},
"fail/db.UpdateAuthorization-error": func(t *testing.T) test {
now := clock.Now()
az := &Authorization{
ID: "azID",
AccountID: "accID",
Status: StatusPending,
ExpiresAt: now.Add(-5 * time.Minute),
}
return test{
az: az,
db: &MockDB{
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
assert.Equals(t, updaz.ID, az.ID)
assert.Equals(t, updaz.AccountID, az.AccountID)
assert.Equals(t, updaz.Status, StatusInvalid)
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
return errors.New("force")
},
},
err: NewErrorISE("error updating authorization: force"),
}
},
"ok/no-valid-challenges": func(t *testing.T) test {
now := clock.Now()
az := &Authorization{
ID: "azID",
AccountID: "accID",
Status: StatusPending,
ExpiresAt: now.Add(5 * time.Minute),
Challenges: []*Challenge{
{Status: StatusPending}, {Status: StatusPending}, {Status: StatusPending},
},
}
return test{
az: az,
}
},
"ok/valid": func(t *testing.T) test {
now := clock.Now()
az := &Authorization{
ID: "azID",
AccountID: "accID",
Status: StatusPending,
ExpiresAt: now.Add(5 * time.Minute),
Challenges: []*Challenge{
{Status: StatusPending}, {Status: StatusPending}, {Status: StatusValid},
},
}
return test{
az: az,
db: &MockDB{
MockUpdateAuthorization: func(ctx context.Context, updaz *Authorization) error {
assert.Equals(t, updaz.ID, az.ID)
assert.Equals(t, updaz.AccountID, az.AccountID)
assert.Equals(t, updaz.Status, StatusValid)
assert.Equals(t, updaz.ExpiresAt, az.ExpiresAt)
assert.Equals(t, updaz.Error, nil)
return nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := tc.az.UpdateStatus(context.Background(), tc.db); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *Error:
assert.Equals(t, k.Type, tc.err.Type)
assert.Equals(t, k.Detail, tc.err.Detail)
assert.Equals(t, k.Status, tc.err.Status)
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
assert.Equals(t, k.Detail, tc.err.Detail)
default:
assert.FatalError(t, errors.New("unexpected error type"))
}
}
} else {
assert.Nil(t, tc.err)
}
})
}
}

View file

@ -1,824 +0,0 @@
package acme
/*
func newAz() (*Authorization, error) {
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("foo"), true, nil
},
}
return newAuthz(mockdb, "1234", Identifier{
Type: "dns", Value: "acme.example.com",
})
}
func TestGetAuthz(t *testing.T) {
type test struct {
id string
db nosql.DB
az authz
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
id: az.getID(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, database.ErrNotFound
},
},
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
}
},
"fail/db-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
id: az.getID(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Identifier.Type = "foo"
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
id: az.getID(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authzTable)
assert.Equals(t, key, []byte(az.getID()))
return b, nil
},
},
err: ServerInternalErr(errors.New("unexpected authz type foo")),
}
},
"ok": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
id: az.getID(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, authzTable)
assert.Equals(t, key, []byte(az.getID()))
return b, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if az, err := getAuthz(tc.db, tc.id); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.az.getID(), az.getID())
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
assert.Equals(t, tc.az.getStatus(), az.getStatus())
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
assert.Equals(t, tc.az.getCreated(), az.getCreated())
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
}
}
})
}
}
func TestAuthzClone(t *testing.T) {
az, err := newAz()
assert.FatalError(t, err)
clone := az.clone()
assert.Equals(t, clone.getID(), az.getID())
assert.Equals(t, clone.getAccountID(), az.getAccountID())
assert.Equals(t, clone.getStatus(), az.getStatus())
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
assert.Equals(t, clone.getExpiry(), az.getExpiry())
assert.Equals(t, clone.getCreated(), az.getCreated())
assert.Equals(t, clone.getChallenges(), az.getChallenges())
clone.Status = StatusValid
assert.NotEquals(t, clone.getStatus(), az.getStatus())
}
func TestNewAuthz(t *testing.T) {
iden := Identifier{
Type: "dns", Value: "acme.example.com",
}
accID := "1234"
type test struct {
iden Identifier
db nosql.DB
err *Error
resChs *([]string)
}
tests := map[string]func(t *testing.T) test{
"fail/unexpected-type": func(t *testing.T) test {
return test{
iden: Identifier{Type: "foo", Value: "acme.example.com"},
err: MalformedErr(errors.New("unexpected authz type foo")),
}
},
"fail/new-http-chall-error": func(t *testing.T) test {
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
}
},
"fail/new-tls-alpn-chall-error": func(t *testing.T) test {
count := 0
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 1 {
return nil, false, errors.New("force")
}
count++
return nil, true, nil
},
},
err: ServerInternalErr(errors.New("error creating alpn challenge: error saving acme challenge: force")),
}
},
"fail/new-dns-chall-error": func(t *testing.T) test {
count := 0
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 2 {
return nil, false, errors.New("force")
}
count++
return nil, true, nil
},
},
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
}
},
"fail/save-authz-error": func(t *testing.T) test {
count := 0
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 3 {
return nil, false, errors.New("force")
}
count++
return nil, true, nil
},
},
err: ServerInternalErr(errors.New("error storing authz: force")),
}
},
"ok": func(t *testing.T) test {
chs := &([]string{})
count := 0
return test{
iden: iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 3 {
assert.Equals(t, bucket, authzTable)
assert.Equals(t, old, nil)
az, err := unmarshalAuthz(newval)
assert.FatalError(t, err)
assert.Equals(t, az.getID(), string(key))
assert.Equals(t, az.getAccountID(), accID)
assert.Equals(t, az.getStatus(), StatusPending)
assert.Equals(t, az.getIdentifier(), iden)
assert.Equals(t, az.getWildcard(), false)
*chs = az.getChallenges()
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
expiry := az.getCreated().Add(defaultExpiryDuration)
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
}
count++
return nil, true, nil
},
},
resChs: chs,
}
},
"ok/wildcard": func(t *testing.T) test {
chs := &([]string{})
count := 0
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
return test{
iden: _iden,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 1 {
assert.Equals(t, bucket, authzTable)
assert.Equals(t, old, nil)
az, err := unmarshalAuthz(newval)
assert.FatalError(t, err)
assert.Equals(t, az.getID(), string(key))
assert.Equals(t, az.getAccountID(), accID)
assert.Equals(t, az.getStatus(), StatusPending)
assert.Equals(t, az.getIdentifier(), iden)
assert.Equals(t, az.getWildcard(), true)
*chs = az.getChallenges()
// Verify that we only have 1 challenge instead of 2.
assert.True(t, len(*chs) == 1)
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
expiry := az.getCreated().Add(defaultExpiryDuration)
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
}
count++
return nil, true, nil
},
},
resChs: chs,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
az, err := newAuthz(tc.db, accID, tc.iden)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, az.getAccountID(), accID)
assert.Equals(t, az.getType(), "dns")
assert.Equals(t, az.getStatus(), StatusPending)
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
expiry := az.getCreated().Add(defaultExpiryDuration)
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
assert.Equals(t, az.getChallenges(), *(tc.resChs))
if strings.HasPrefix(tc.iden.Value, "*.") {
assert.True(t, az.getWildcard())
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
} else {
assert.False(t, az.getWildcard())
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
}
assert.True(t, az.getID() != "")
}
}
})
}
}
func TestAuthzToACME(t *testing.T) {
dir := newDirectory("ca.smallstep.com", "acme")
var (
ch1, ch2 challenge
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
err error
)
count := 0
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
*ch1Bytes = newval
ch1, err = unmarshalChallenge(newval)
assert.FatalError(t, err)
} else if count == 1 {
*ch2Bytes = newval
ch2, err = unmarshalChallenge(newval)
assert.FatalError(t, err)
}
count++
return []byte("foo"), true, nil
},
}
iden := Identifier{
Type: "dns", Value: "acme.example.com",
}
az, err := newAuthz(mockdb, "1234", iden)
assert.FatalError(t, err)
prov := newProv()
ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
ctx = context.WithValue(ctx, BaseURLContextKey, "https://test.ca.smallstep.com:8080")
type test struct {
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/getChallenge1-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading challenge")),
}
},
"fail/getChallenge2-error": func(t *testing.T) test {
count := 0
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 1 {
return nil, errors.New("force")
}
count++
return *ch1Bytes, nil
},
},
err: ServerInternalErr(errors.New("error loading challenge")),
}
},
"ok": func(t *testing.T) test {
count := 0
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 0 {
count++
return *ch1Bytes, nil
}
return *ch2Bytes, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
acmeAz, err := az.toACME(ctx, tc.db, dir)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, acmeAz.ID, az.getID())
assert.Equals(t, acmeAz.Identifier, iden)
assert.Equals(t, acmeAz.Status, StatusPending)
acmeCh1, err := ch1.toACME(ctx, nil, dir)
assert.FatalError(t, err)
acmeCh2, err := ch2.toACME(ctx, nil, dir)
assert.FatalError(t, err)
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
assert.FatalError(t, err)
assert.Equals(t, expiry.String(), az.getExpiry().String())
}
}
})
}
}
func TestAuthzSave(t *testing.T) {
type test struct {
az, old authz
db nosql.DB
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/old-nil/swap-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing authz: force")),
}
},
"fail/old-nil/swap-false": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return []byte("foo"), false, nil
},
},
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
}
},
"ok/old-nil": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
old: nil,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, old, nil)
assert.Equals(t, b, newval)
assert.Equals(t, bucket, authzTable)
assert.Equals(t, []byte(az.getID()), key)
return nil, true, nil
},
},
}
},
"ok/old-not-nil": func(t *testing.T) test {
oldAz, err := newAz()
assert.FatalError(t, err)
az, err := newAz()
assert.FatalError(t, err)
oldb, err := json.Marshal(oldAz)
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
old: oldAz,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
assert.Equals(t, old, oldb)
assert.Equals(t, b, newval)
assert.Equals(t, bucket, authzTable)
assert.Equals(t, []byte(az.getID()), key)
return []byte("foo"), true, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := tc.az.save(tc.db, tc.old); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
func TestAuthzUnmarshal(t *testing.T) {
type test struct {
az authz
azb []byte
err *Error
}
tests := map[string]func(t *testing.T) test{
"fail/nil": func(t *testing.T) test {
return test{
azb: nil,
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
}
},
"fail/unexpected-type": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Identifier.Type = "foo"
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
azb: b,
err: ServerInternalErr(errors.New("unexpected authz type foo")),
}
},
"ok/dns": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
return test{
az: az,
azb: b,
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if az, err := unmarshalAuthz(tc.azb); err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.az.getID(), az.getID())
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
assert.Equals(t, tc.az.getStatus(), az.getStatus())
assert.Equals(t, tc.az.getCreated(), az.getCreated())
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
}
}
})
}
}
func TestAuthzUpdateStatus(t *testing.T) {
type test struct {
az, res authz
err *Error
db nosql.DB
}
tests := map[string]func(t *testing.T) test{
"fail/already-invalid": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Status = StatusInvalid
return test{
az: az,
res: az,
}
},
"fail/already-valid": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Status = StatusValid
return test{
az: az,
res: az,
}
},
"fail/unexpected-status": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Status = StatusReady
return test{
az: az,
res: az,
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
}
},
"fail/save-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
return test{
az: az,
res: az,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, false, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error storing authz: force")),
}
},
"ok/expired": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
clone := az.clone()
clone.Error = MalformedErr(errors.New("authz has expired"))
clone.Status = StatusInvalid
return test{
az: az,
res: clone.parent(),
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, true, nil
},
},
}
},
"fail/get-challenge-error": func(t *testing.T) test {
az, err := newAz()
assert.FatalError(t, err)
return test{
az: az,
res: az,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
return nil, errors.New("force")
},
},
err: ServerInternalErr(errors.New("error loading challenge")),
}
},
"ok/valid": func(t *testing.T) test {
var (
ch3 challenge
ch2Bytes = &([]byte{})
ch1Bytes = &([]byte{})
err error
)
count := 0
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
*ch1Bytes = newval
} else if count == 1 {
*ch2Bytes = newval
} else if count == 2 {
ch3, err = unmarshalChallenge(newval)
assert.FatalError(t, err)
}
count++
return nil, true, nil
},
}
iden := Identifier{
Type: "dns", Value: "acme.example.com",
}
az, err := newAuthz(mockdb, "1234", iden)
assert.FatalError(t, err)
_az, ok := az.(*dnsAuthz)
assert.Fatal(t, ok)
_az.baseAuthz.Error = MalformedErr(nil)
_ch, ok := ch3.(*dns01Challenge)
assert.Fatal(t, ok)
_ch.baseChallenge.Status = StatusValid
chb, err := json.Marshal(ch3)
clone := az.clone()
clone.Status = StatusValid
clone.Error = nil
count = 0
return test{
az: az,
res: clone.parent(),
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 0 {
count++
return *ch1Bytes, nil
}
if count == 1 {
count++
return *ch2Bytes, nil
}
count++
return chb, nil
},
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
return nil, true, nil
},
},
}
},
"ok/still-pending": func(t *testing.T) test {
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
count := 0
mockdb := &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
if count == 0 {
*ch1Bytes = newval
} else if count == 1 {
*ch2Bytes = newval
}
count++
return nil, true, nil
},
}
iden := Identifier{
Type: "dns", Value: "acme.example.com",
}
az, err := newAuthz(mockdb, "1234", iden)
assert.FatalError(t, err)
count = 0
return test{
az: az,
res: az,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
if count == 0 {
count++
return *ch1Bytes, nil
}
count++
return *ch2Bytes, nil
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
az, err := tc.az.updateStatus(tc.db)
if err != nil {
if assert.NotNil(t, tc.err) {
ae, ok := err.(*Error)
assert.True(t, ok)
assert.HasPrefix(t, ae.Error(), tc.err.Error())
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
assert.Equals(t, ae.Type, tc.err.Type)
}
} else {
if assert.Nil(t, tc.err) {
expB, err := json.Marshal(tc.res)
assert.FatalError(t, err)
b, err := json.Marshal(az)
assert.FatalError(t, err)
assert.Equals(t, expB, b)
}
}
})
}
}
*/

View file

@ -81,10 +81,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
for _, azID := range o.AuthorizationIDs { for _, azID := range o.AuthorizationIDs {
az, err := db.GetAuthorization(ctx, azID) az, err := db.GetAuthorization(ctx, azID)
if err != nil { if err != nil {
return err return WrapErrorISE(err, "error getting authorization ID %s", azID)
} }
if err = az.UpdateStatus(ctx, db); err != nil { if err = az.UpdateStatus(ctx, db); err != nil {
return err return WrapErrorISE(err, "error updating authorization ID %s", azID)
} }
st := az.Status st := az.Status
count[st]++ count[st]++
@ -107,7 +107,10 @@ func (o *Order) UpdateStatus(ctx context.Context, db DB) error {
default: default:
return NewErrorISE("unrecognized order status: %s", o.Status) return NewErrorISE("unrecognized order status: %s", o.Status)
} }
return db.UpdateOrder(ctx, o) if err := db.UpdateOrder(ctx, o); err != nil {
return WrapErrorISE(err, "error updating order")
}
return nil
} }
// Finalize signs a certificate if the necessary conditions for Order completion // Finalize signs a certificate if the necessary conditions for Order completion

View file

@ -1,5 +1,261 @@
package acme package acme
import (
"context"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
)
func TestOrder_UpdateStatus(t *testing.T) {
type test struct {
o *Order
err *Error
db DB
}
tests := map[string]func(t *testing.T) test{
"ok/already-invalid": func(t *testing.T) test {
o := &Order{
Status: StatusInvalid,
}
return test{
o: o,
}
},
"ok/already-valid": func(t *testing.T) test {
o := &Order{
Status: StatusInvalid,
}
return test{
o: o,
}
},
"fail/error-unexpected-status": func(t *testing.T) test {
o := &Order{
Status: "foo",
}
return test{
o: o,
err: NewErrorISE("unrecognized order status: %s", o.Status),
}
},
"ok/ready-expired": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusReady,
ExpiresAt: now.Add(-5 * time.Minute),
}
return test{
o: o,
db: &MockDB{
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
assert.Equals(t, updo.ID, o.ID)
assert.Equals(t, updo.AccountID, o.AccountID)
assert.Equals(t, updo.Status, StatusInvalid)
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
return nil
},
},
}
},
"fail/ready-expired-db.UpdateOrder-error": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusReady,
ExpiresAt: now.Add(-5 * time.Minute),
}
return test{
o: o,
db: &MockDB{
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
assert.Equals(t, updo.ID, o.ID)
assert.Equals(t, updo.AccountID, o.AccountID)
assert.Equals(t, updo.Status, StatusInvalid)
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
return errors.New("force")
},
},
err: NewErrorISE("error updating order: force"),
}
},
"ok/pending-expired": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusPending,
ExpiresAt: now.Add(-5 * time.Minute),
}
return test{
o: o,
db: &MockDB{
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
assert.Equals(t, updo.ID, o.ID)
assert.Equals(t, updo.AccountID, o.AccountID)
assert.Equals(t, updo.Status, StatusInvalid)
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
err := NewError(ErrorMalformedType, "order has expired")
assert.HasPrefix(t, updo.Error.Err.Error(), err.Err.Error())
assert.Equals(t, updo.Error.Type, err.Type)
assert.Equals(t, updo.Error.Detail, err.Detail)
assert.Equals(t, updo.Error.Status, err.Status)
assert.Equals(t, updo.Error.Detail, err.Detail)
return nil
},
},
}
},
"ok/invalid": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusPending,
ExpiresAt: now.Add(5 * time.Minute),
AuthorizationIDs: []string{"a", "b"},
}
az1 := &Authorization{
ID: "a",
Status: StatusValid,
}
az2 := &Authorization{
ID: "b",
Status: StatusInvalid,
}
return test{
o: o,
db: &MockDB{
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
assert.Equals(t, updo.ID, o.ID)
assert.Equals(t, updo.AccountID, o.AccountID)
assert.Equals(t, updo.Status, StatusInvalid)
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
return nil
},
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
switch id {
case az1.ID:
return az1, nil
case az2.ID:
return az2, nil
default:
assert.FatalError(t, errors.Errorf("unexpected authz key %s", id))
return nil, errors.New("force")
}
},
},
}
},
"ok/still-pending": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusPending,
ExpiresAt: now.Add(5 * time.Minute),
AuthorizationIDs: []string{"a", "b"},
}
az1 := &Authorization{
ID: "a",
Status: StatusValid,
}
az2 := &Authorization{
ID: "b",
Status: StatusPending,
}
return test{
o: o,
db: &MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
switch id {
case az1.ID:
return az1, nil
case az2.ID:
return az2, nil
default:
assert.FatalError(t, errors.Errorf("unexpected authz key %s", id))
return nil, errors.New("force")
}
},
},
}
},
"ok/valid": func(t *testing.T) test {
now := clock.Now()
o := &Order{
ID: "oID",
AccountID: "accID",
Status: StatusPending,
ExpiresAt: now.Add(5 * time.Minute),
AuthorizationIDs: []string{"a", "b"},
}
az1 := &Authorization{
ID: "a",
Status: StatusValid,
}
az2 := &Authorization{
ID: "b",
Status: StatusValid,
}
return test{
o: o,
db: &MockDB{
MockUpdateOrder: func(ctx context.Context, updo *Order) error {
assert.Equals(t, updo.ID, o.ID)
assert.Equals(t, updo.AccountID, o.AccountID)
assert.Equals(t, updo.Status, StatusReady)
assert.Equals(t, updo.ExpiresAt, o.ExpiresAt)
return nil
},
MockGetAuthorization: func(ctx context.Context, id string) (*Authorization, error) {
switch id {
case az1.ID:
return az1, nil
case az2.ID:
return az2, nil
default:
assert.FatalError(t, errors.Errorf("unexpected authz key %s", id))
return nil, errors.New("force")
}
},
},
}
},
}
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
if err := tc.o.UpdateStatus(context.Background(), tc.db); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *Error:
assert.Equals(t, k.Type, tc.err.Type)
assert.Equals(t, k.Detail, tc.err.Detail)
assert.Equals(t, k.Status, tc.err.Status)
assert.Equals(t, k.Err.Error(), tc.err.Err.Error())
assert.Equals(t, k.Detail, tc.err.Detail)
default:
assert.FatalError(t, errors.New("unexpected error type"))
}
}
} else {
assert.Nil(t, tc.err)
}
})
}
}
/* /*
var certDuration = 6 * time.Hour var certDuration = 6 * time.Hour