810 lines
21 KiB
Go
810 lines
21 KiB
Go
|
package acme
|
||
|
|
||
|
import (
|
||
|
"encoding/json"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pkg/errors"
|
||
|
"github.com/smallstep/assert"
|
||
|
"github.com/smallstep/certificates/db"
|
||
|
"github.com/smallstep/nosql"
|
||
|
"github.com/smallstep/nosql/database"
|
||
|
)
|
||
|
|
||
|
func newAz() (authz, error) {
|
||
|
mockdb := &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
return []byte("foo"), true, nil
|
||
|
},
|
||
|
}
|
||
|
return newAuthz(mockdb, "1234", Identifier{
|
||
|
Type: "dns", Value: "acme.example.com",
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestGetAuthz(t *testing.T) {
|
||
|
type test struct {
|
||
|
id string
|
||
|
db nosql.DB
|
||
|
az authz
|
||
|
err *Error
|
||
|
}
|
||
|
tests := map[string]func(t *testing.T) test{
|
||
|
"fail/not-found": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
id: az.getID(),
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
return nil, database.ErrNotFound
|
||
|
},
|
||
|
},
|
||
|
err: MalformedErr(errors.Errorf("authz %s not found: not found", az.getID())),
|
||
|
}
|
||
|
},
|
||
|
"fail/db-error": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
id: az.getID(),
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
return nil, errors.New("force")
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.Errorf("error loading authz %s: force", az.getID())),
|
||
|
}
|
||
|
},
|
||
|
"fail/unmarshal-error": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
_az, ok := az.(*dnsAuthz)
|
||
|
assert.Fatal(t, ok)
|
||
|
_az.baseAuthz.Identifier.Type = "foo"
|
||
|
b, err := json.Marshal(az)
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
id: az.getID(),
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
assert.Equals(t, bucket, authzTable)
|
||
|
assert.Equals(t, key, []byte(az.getID()))
|
||
|
return b, nil
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||
|
}
|
||
|
},
|
||
|
"ok": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
b, err := json.Marshal(az)
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
id: az.getID(),
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
assert.Equals(t, bucket, authzTable)
|
||
|
assert.Equals(t, key, []byte(az.getID()))
|
||
|
return b, nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
for name, run := range tests {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
tc := run(t)
|
||
|
if az, err := getAuthz(tc.db, tc.id); err != nil {
|
||
|
if assert.NotNil(t, tc.err) {
|
||
|
ae, ok := err.(*Error)
|
||
|
assert.True(t, ok)
|
||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||
|
}
|
||
|
} else {
|
||
|
if assert.Nil(t, tc.err) {
|
||
|
assert.Equals(t, tc.az.getID(), az.getID())
|
||
|
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||
|
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||
|
assert.Equals(t, tc.az.getIdentifier(), az.getIdentifier())
|
||
|
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||
|
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||
|
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAuthzClone(t *testing.T) {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
|
||
|
clone := az.clone()
|
||
|
|
||
|
assert.Equals(t, clone.getID(), az.getID())
|
||
|
assert.Equals(t, clone.getAccountID(), az.getAccountID())
|
||
|
assert.Equals(t, clone.getStatus(), az.getStatus())
|
||
|
assert.Equals(t, clone.getIdentifier(), az.getIdentifier())
|
||
|
assert.Equals(t, clone.getExpiry(), az.getExpiry())
|
||
|
assert.Equals(t, clone.getCreated(), az.getCreated())
|
||
|
assert.Equals(t, clone.getChallenges(), az.getChallenges())
|
||
|
|
||
|
clone.Status = StatusValid
|
||
|
|
||
|
assert.NotEquals(t, clone.getStatus(), az.getStatus())
|
||
|
}
|
||
|
|
||
|
func TestNewAuthz(t *testing.T) {
|
||
|
iden := Identifier{
|
||
|
Type: "dns", Value: "acme.example.com",
|
||
|
}
|
||
|
accID := "1234"
|
||
|
type test struct {
|
||
|
iden Identifier
|
||
|
db nosql.DB
|
||
|
err *Error
|
||
|
resChs *([]string)
|
||
|
}
|
||
|
tests := map[string]func(t *testing.T) test{
|
||
|
"fail/unexpected-type": func(t *testing.T) test {
|
||
|
return test{
|
||
|
iden: Identifier{Type: "foo", Value: "acme.example.com"},
|
||
|
err: MalformedErr(errors.New("unexpected authz type foo")),
|
||
|
}
|
||
|
},
|
||
|
"fail/new-http-chall-error": func(t *testing.T) test {
|
||
|
return test{
|
||
|
iden: iden,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
return nil, false, errors.New("force")
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error creating http challenge: error saving acme challenge: force")),
|
||
|
}
|
||
|
},
|
||
|
"fail/new-dns-chall-error": func(t *testing.T) test {
|
||
|
count := 0
|
||
|
return test{
|
||
|
iden: iden,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
if count == 1 {
|
||
|
return nil, false, errors.New("force")
|
||
|
}
|
||
|
count++
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error creating dns challenge: error saving acme challenge: force")),
|
||
|
}
|
||
|
},
|
||
|
"fail/save-authz-error": func(t *testing.T) test {
|
||
|
count := 0
|
||
|
return test{
|
||
|
iden: iden,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
if count == 2 {
|
||
|
return nil, false, errors.New("force")
|
||
|
}
|
||
|
count++
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||
|
}
|
||
|
},
|
||
|
"ok": func(t *testing.T) test {
|
||
|
chs := &([]string{})
|
||
|
count := 0
|
||
|
return test{
|
||
|
iden: iden,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
if count == 2 {
|
||
|
assert.Equals(t, bucket, authzTable)
|
||
|
assert.Equals(t, old, nil)
|
||
|
|
||
|
az, err := unmarshalAuthz(newval)
|
||
|
assert.FatalError(t, err)
|
||
|
|
||
|
assert.Equals(t, az.getID(), string(key))
|
||
|
assert.Equals(t, az.getAccountID(), accID)
|
||
|
assert.Equals(t, az.getStatus(), StatusPending)
|
||
|
assert.Equals(t, az.getIdentifier(), iden)
|
||
|
assert.Equals(t, az.getWildcard(), false)
|
||
|
|
||
|
*chs = az.getChallenges()
|
||
|
|
||
|
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||
|
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||
|
|
||
|
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||
|
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||
|
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||
|
}
|
||
|
count++
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
},
|
||
|
resChs: chs,
|
||
|
}
|
||
|
},
|
||
|
"ok/wildcard": func(t *testing.T) test {
|
||
|
chs := &([]string{})
|
||
|
count := 0
|
||
|
_iden := Identifier{Type: "dns", Value: "*.acme.example.com"}
|
||
|
return test{
|
||
|
iden: _iden,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
if count == 1 {
|
||
|
assert.Equals(t, bucket, authzTable)
|
||
|
assert.Equals(t, old, nil)
|
||
|
|
||
|
az, err := unmarshalAuthz(newval)
|
||
|
assert.FatalError(t, err)
|
||
|
|
||
|
assert.Equals(t, az.getID(), string(key))
|
||
|
assert.Equals(t, az.getAccountID(), accID)
|
||
|
assert.Equals(t, az.getStatus(), StatusPending)
|
||
|
assert.Equals(t, az.getIdentifier(), iden)
|
||
|
assert.Equals(t, az.getWildcard(), true)
|
||
|
|
||
|
*chs = az.getChallenges()
|
||
|
// Verify that we only have 1 challenge instead of 2.
|
||
|
assert.True(t, len(*chs) == 1)
|
||
|
|
||
|
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||
|
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||
|
|
||
|
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||
|
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||
|
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||
|
}
|
||
|
count++
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
},
|
||
|
resChs: chs,
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
for name, run := range tests {
|
||
|
tc := run(t)
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
az, err := newAuthz(tc.db, accID, tc.iden)
|
||
|
if err != nil {
|
||
|
if assert.NotNil(t, tc.err) {
|
||
|
ae, ok := err.(*Error)
|
||
|
assert.True(t, ok)
|
||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||
|
}
|
||
|
} else {
|
||
|
if assert.Nil(t, tc.err) {
|
||
|
assert.Equals(t, az.getAccountID(), accID)
|
||
|
assert.Equals(t, az.getType(), "dns")
|
||
|
assert.Equals(t, az.getStatus(), StatusPending)
|
||
|
|
||
|
assert.True(t, az.getCreated().Before(time.Now().UTC().Add(time.Minute)))
|
||
|
assert.True(t, az.getCreated().After(time.Now().UTC().Add(-1*time.Minute)))
|
||
|
|
||
|
expiry := az.getCreated().Add(defaultExpiryDuration)
|
||
|
assert.True(t, az.getExpiry().Before(expiry.Add(time.Minute)))
|
||
|
assert.True(t, az.getExpiry().After(expiry.Add(-1*time.Minute)))
|
||
|
|
||
|
assert.Equals(t, az.getChallenges(), *(tc.resChs))
|
||
|
|
||
|
if strings.HasPrefix(tc.iden.Value, "*.") {
|
||
|
assert.True(t, az.getWildcard())
|
||
|
assert.Equals(t, az.getIdentifier().Value, strings.TrimPrefix(tc.iden.Value, "*."))
|
||
|
} else {
|
||
|
assert.False(t, az.getWildcard())
|
||
|
assert.Equals(t, az.getIdentifier().Value, tc.iden.Value)
|
||
|
}
|
||
|
|
||
|
assert.True(t, az.getID() != "")
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAuthzToACME(t *testing.T) {
|
||
|
dir := newDirectory("ca.smallstep.com", "acme")
|
||
|
|
||
|
var (
|
||
|
ch1, ch2 challenge
|
||
|
ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||
|
err error
|
||
|
)
|
||
|
|
||
|
count := 0
|
||
|
mockdb := &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
if count == 0 {
|
||
|
*ch1Bytes = newval
|
||
|
ch1, err = unmarshalChallenge(newval)
|
||
|
assert.FatalError(t, err)
|
||
|
} else if count == 1 {
|
||
|
*ch2Bytes = newval
|
||
|
ch2, err = unmarshalChallenge(newval)
|
||
|
assert.FatalError(t, err)
|
||
|
}
|
||
|
count++
|
||
|
return []byte("foo"), true, nil
|
||
|
},
|
||
|
}
|
||
|
iden := Identifier{
|
||
|
Type: "dns", Value: "acme.example.com",
|
||
|
}
|
||
|
az, err := newAuthz(mockdb, "1234", iden)
|
||
|
assert.FatalError(t, err)
|
||
|
prov := newProv()
|
||
|
|
||
|
type test struct {
|
||
|
db nosql.DB
|
||
|
err *Error
|
||
|
}
|
||
|
tests := map[string]func(t *testing.T) test{
|
||
|
"fail/getChallenge1-error": func(t *testing.T) test {
|
||
|
return test{
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
return nil, errors.New("force")
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error loading challenge")),
|
||
|
}
|
||
|
},
|
||
|
"fail/getChallenge2-error": func(t *testing.T) test {
|
||
|
count := 0
|
||
|
return test{
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
if count == 1 {
|
||
|
return nil, errors.New("force")
|
||
|
}
|
||
|
count++
|
||
|
return *ch1Bytes, nil
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error loading challenge")),
|
||
|
}
|
||
|
},
|
||
|
"ok": func(t *testing.T) test {
|
||
|
count := 0
|
||
|
return test{
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
if count == 0 {
|
||
|
count++
|
||
|
return *ch1Bytes, nil
|
||
|
}
|
||
|
return *ch2Bytes, nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
for name, run := range tests {
|
||
|
tc := run(t)
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
acmeAz, err := az.toACME(tc.db, dir, prov)
|
||
|
if err != nil {
|
||
|
if assert.NotNil(t, tc.err) {
|
||
|
ae, ok := err.(*Error)
|
||
|
assert.True(t, ok)
|
||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||
|
}
|
||
|
} else {
|
||
|
if assert.Nil(t, tc.err) {
|
||
|
assert.Equals(t, acmeAz.ID, az.getID())
|
||
|
assert.Equals(t, acmeAz.Identifier, iden)
|
||
|
assert.Equals(t, acmeAz.Status, StatusPending)
|
||
|
|
||
|
acmeCh1, err := ch1.toACME(nil, dir, prov)
|
||
|
assert.FatalError(t, err)
|
||
|
acmeCh2, err := ch2.toACME(nil, dir, prov)
|
||
|
assert.FatalError(t, err)
|
||
|
|
||
|
assert.Equals(t, acmeAz.Challenges[0], acmeCh1)
|
||
|
assert.Equals(t, acmeAz.Challenges[1], acmeCh2)
|
||
|
|
||
|
expiry, err := time.Parse(time.RFC3339, acmeAz.Expires)
|
||
|
assert.FatalError(t, err)
|
||
|
assert.Equals(t, expiry.String(), az.getExpiry().String())
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAuthzSave(t *testing.T) {
|
||
|
type test struct {
|
||
|
az, old authz
|
||
|
db nosql.DB
|
||
|
err *Error
|
||
|
}
|
||
|
tests := map[string]func(t *testing.T) test{
|
||
|
"fail/old-nil/swap-error": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
old: nil,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
return nil, false, errors.New("force")
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||
|
}
|
||
|
},
|
||
|
"fail/old-nil/swap-false": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
old: nil,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
return []byte("foo"), false, nil
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error storing authz; value has changed since last read")),
|
||
|
}
|
||
|
},
|
||
|
"ok/old-nil": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
b, err := json.Marshal(az)
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
old: nil,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
assert.Equals(t, old, nil)
|
||
|
assert.Equals(t, b, newval)
|
||
|
assert.Equals(t, bucket, authzTable)
|
||
|
assert.Equals(t, []byte(az.getID()), key)
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
"ok/old-not-nil": func(t *testing.T) test {
|
||
|
oldAz, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
|
||
|
oldb, err := json.Marshal(oldAz)
|
||
|
assert.FatalError(t, err)
|
||
|
b, err := json.Marshal(az)
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
old: oldAz,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
assert.Equals(t, old, oldb)
|
||
|
assert.Equals(t, b, newval)
|
||
|
assert.Equals(t, bucket, authzTable)
|
||
|
assert.Equals(t, []byte(az.getID()), key)
|
||
|
return []byte("foo"), true, nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
for name, run := range tests {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
tc := run(t)
|
||
|
if err := tc.az.save(tc.db, tc.old); err != nil {
|
||
|
if assert.NotNil(t, tc.err) {
|
||
|
ae, ok := err.(*Error)
|
||
|
assert.True(t, ok)
|
||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||
|
}
|
||
|
} else {
|
||
|
assert.Nil(t, tc.err)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAuthzUnmarshal(t *testing.T) {
|
||
|
type test struct {
|
||
|
az authz
|
||
|
azb []byte
|
||
|
err *Error
|
||
|
}
|
||
|
tests := map[string]func(t *testing.T) test{
|
||
|
"fail/nil": func(t *testing.T) test {
|
||
|
return test{
|
||
|
azb: nil,
|
||
|
err: ServerInternalErr(errors.New("error unmarshaling authz type: unexpected end of JSON input")),
|
||
|
}
|
||
|
},
|
||
|
"fail/unexpected-type": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
_az, ok := az.(*dnsAuthz)
|
||
|
assert.Fatal(t, ok)
|
||
|
_az.baseAuthz.Identifier.Type = "foo"
|
||
|
b, err := json.Marshal(az)
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
azb: b,
|
||
|
err: ServerInternalErr(errors.New("unexpected authz type foo")),
|
||
|
}
|
||
|
},
|
||
|
"ok/dns": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
b, err := json.Marshal(az)
|
||
|
assert.FatalError(t, err)
|
||
|
return test{
|
||
|
az: az,
|
||
|
azb: b,
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
for name, run := range tests {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
tc := run(t)
|
||
|
if az, err := unmarshalAuthz(tc.azb); err != nil {
|
||
|
if assert.NotNil(t, tc.err) {
|
||
|
ae, ok := err.(*Error)
|
||
|
assert.True(t, ok)
|
||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||
|
}
|
||
|
} else {
|
||
|
if assert.Nil(t, tc.err) {
|
||
|
assert.Equals(t, tc.az.getID(), az.getID())
|
||
|
assert.Equals(t, tc.az.getAccountID(), az.getAccountID())
|
||
|
assert.Equals(t, tc.az.getStatus(), az.getStatus())
|
||
|
assert.Equals(t, tc.az.getCreated(), az.getCreated())
|
||
|
assert.Equals(t, tc.az.getExpiry(), az.getExpiry())
|
||
|
assert.Equals(t, tc.az.getWildcard(), az.getWildcard())
|
||
|
assert.Equals(t, tc.az.getChallenges(), az.getChallenges())
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAuthzUpdateStatus(t *testing.T) {
|
||
|
type test struct {
|
||
|
az, res authz
|
||
|
err *Error
|
||
|
db nosql.DB
|
||
|
}
|
||
|
tests := map[string]func(t *testing.T) test{
|
||
|
"fail/already-invalid": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
_az, ok := az.(*dnsAuthz)
|
||
|
assert.Fatal(t, ok)
|
||
|
_az.baseAuthz.Status = StatusInvalid
|
||
|
return test{
|
||
|
az: az,
|
||
|
res: az,
|
||
|
}
|
||
|
},
|
||
|
"fail/already-valid": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
_az, ok := az.(*dnsAuthz)
|
||
|
assert.Fatal(t, ok)
|
||
|
_az.baseAuthz.Status = StatusValid
|
||
|
return test{
|
||
|
az: az,
|
||
|
res: az,
|
||
|
}
|
||
|
},
|
||
|
"fail/unexpected-status": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
_az, ok := az.(*dnsAuthz)
|
||
|
assert.Fatal(t, ok)
|
||
|
_az.baseAuthz.Status = StatusReady
|
||
|
return test{
|
||
|
az: az,
|
||
|
res: az,
|
||
|
err: ServerInternalErr(errors.New("unrecognized authz status: ready")),
|
||
|
}
|
||
|
},
|
||
|
"fail/save-error": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
_az, ok := az.(*dnsAuthz)
|
||
|
assert.Fatal(t, ok)
|
||
|
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||
|
return test{
|
||
|
az: az,
|
||
|
res: az,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
return nil, false, errors.New("force")
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error storing authz: force")),
|
||
|
}
|
||
|
},
|
||
|
"ok/expired": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
_az, ok := az.(*dnsAuthz)
|
||
|
assert.Fatal(t, ok)
|
||
|
_az.baseAuthz.Expires = time.Now().UTC().Add(-time.Minute)
|
||
|
|
||
|
clone := az.clone()
|
||
|
clone.Error = MalformedErr(errors.New("authz has expired"))
|
||
|
clone.Status = StatusInvalid
|
||
|
return test{
|
||
|
az: az,
|
||
|
res: clone.parent(),
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
"fail/get-challenge-error": func(t *testing.T) test {
|
||
|
az, err := newAz()
|
||
|
assert.FatalError(t, err)
|
||
|
|
||
|
return test{
|
||
|
az: az,
|
||
|
res: az,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
return nil, errors.New("force")
|
||
|
},
|
||
|
},
|
||
|
err: ServerInternalErr(errors.New("error loading challenge")),
|
||
|
}
|
||
|
},
|
||
|
"ok/valid": func(t *testing.T) test {
|
||
|
var (
|
||
|
ch2 challenge
|
||
|
ch1Bytes = &([]byte{})
|
||
|
err error
|
||
|
)
|
||
|
|
||
|
count := 0
|
||
|
mockdb := &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
if count == 0 {
|
||
|
*ch1Bytes = newval
|
||
|
} else if count == 1 {
|
||
|
ch2, err = unmarshalChallenge(newval)
|
||
|
assert.FatalError(t, err)
|
||
|
}
|
||
|
count++
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
}
|
||
|
iden := Identifier{
|
||
|
Type: "dns", Value: "acme.example.com",
|
||
|
}
|
||
|
az, err := newAuthz(mockdb, "1234", iden)
|
||
|
assert.FatalError(t, err)
|
||
|
_az, ok := az.(*dnsAuthz)
|
||
|
assert.Fatal(t, ok)
|
||
|
_az.baseAuthz.Error = MalformedErr(nil)
|
||
|
|
||
|
_ch, ok := ch2.(*dns01Challenge)
|
||
|
assert.Fatal(t, ok)
|
||
|
_ch.baseChallenge.Status = StatusValid
|
||
|
chb, err := json.Marshal(ch2)
|
||
|
|
||
|
clone := az.clone()
|
||
|
clone.Status = StatusValid
|
||
|
clone.Error = nil
|
||
|
|
||
|
count = 0
|
||
|
return test{
|
||
|
az: az,
|
||
|
res: clone.parent(),
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
if count == 0 {
|
||
|
count++
|
||
|
return *ch1Bytes, nil
|
||
|
}
|
||
|
count++
|
||
|
return chb, nil
|
||
|
},
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
"ok/still-pending": func(t *testing.T) test {
|
||
|
var ch1Bytes, ch2Bytes = &([]byte{}), &([]byte{})
|
||
|
|
||
|
count := 0
|
||
|
mockdb := &db.MockNoSQLDB{
|
||
|
MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
|
||
|
if count == 0 {
|
||
|
*ch1Bytes = newval
|
||
|
} else if count == 1 {
|
||
|
*ch2Bytes = newval
|
||
|
}
|
||
|
count++
|
||
|
return nil, true, nil
|
||
|
},
|
||
|
}
|
||
|
iden := Identifier{
|
||
|
Type: "dns", Value: "acme.example.com",
|
||
|
}
|
||
|
az, err := newAuthz(mockdb, "1234", iden)
|
||
|
assert.FatalError(t, err)
|
||
|
|
||
|
count = 0
|
||
|
return test{
|
||
|
az: az,
|
||
|
res: az,
|
||
|
db: &db.MockNoSQLDB{
|
||
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||
|
if count == 0 {
|
||
|
count++
|
||
|
return *ch1Bytes, nil
|
||
|
}
|
||
|
count++
|
||
|
return *ch2Bytes, nil
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
},
|
||
|
}
|
||
|
for name, run := range tests {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
tc := run(t)
|
||
|
az, err := tc.az.updateStatus(tc.db)
|
||
|
if err != nil {
|
||
|
if assert.NotNil(t, tc.err) {
|
||
|
ae, ok := err.(*Error)
|
||
|
assert.True(t, ok)
|
||
|
assert.HasPrefix(t, ae.Error(), tc.err.Error())
|
||
|
assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
|
||
|
assert.Equals(t, ae.Type, tc.err.Type)
|
||
|
}
|
||
|
} else {
|
||
|
if assert.Nil(t, tc.err) {
|
||
|
expB, err := json.Marshal(tc.res)
|
||
|
assert.FatalError(t, err)
|
||
|
b, err := json.Marshal(az)
|
||
|
assert.FatalError(t, err)
|
||
|
assert.Equals(t, expB, b)
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|