464 lines
13 KiB
Go
464 lines
13 KiB
Go
package nosql
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/smallstep/assert"
|
|
"github.com/smallstep/certificates/acme"
|
|
"github.com/smallstep/certificates/db"
|
|
"github.com/smallstep/nosql"
|
|
nosqldb "github.com/smallstep/nosql/database"
|
|
)
|
|
|
|
func TestDB_getDBChallenge(t *testing.T) {
|
|
chID := "chID"
|
|
type test struct {
|
|
db nosql.DB
|
|
err error
|
|
acmeErr *acme.Error
|
|
dbc *dbChallenge
|
|
}
|
|
var tests = map[string]func(t *testing.T) test{
|
|
"fail/not-found": func(t *testing.T) test {
|
|
return test{
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return nil, nosqldb.ErrNotFound
|
|
},
|
|
},
|
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
|
|
}
|
|
},
|
|
"fail/db.Get-error": func(t *testing.T) test {
|
|
return test{
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return nil, errors.New("force")
|
|
},
|
|
},
|
|
err: errors.New("error loading acme challenge chID: force"),
|
|
}
|
|
},
|
|
"fail/unmarshal-error": func(t *testing.T) test {
|
|
return test{
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return []byte("foo"), nil
|
|
},
|
|
},
|
|
err: errors.New("error unmarshaling dbChallenge"),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
dbc := &dbChallenge{
|
|
ID: chID,
|
|
AccountID: "accountID",
|
|
Type: "dns-01",
|
|
Status: acme.StatusPending,
|
|
Token: "token",
|
|
Value: "test.ca.smallstep.com",
|
|
CreatedAt: clock.Now(),
|
|
ValidatedAt: "foobar",
|
|
Error: acme.NewErrorISE("force"),
|
|
}
|
|
b, err := json.Marshal(dbc)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return b, nil
|
|
},
|
|
},
|
|
dbc: dbc,
|
|
}
|
|
},
|
|
}
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
db := DB{db: tc.db}
|
|
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
|
|
switch k := err.(type) {
|
|
case *acme.Error:
|
|
if assert.NotNil(t, tc.acmeErr) {
|
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
|
}
|
|
default:
|
|
if assert.NotNil(t, tc.err) {
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
}
|
|
}
|
|
} else {
|
|
if assert.Nil(t, tc.err) {
|
|
assert.Equals(t, ch.ID, tc.dbc.ID)
|
|
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
|
assert.Equals(t, ch.Type, tc.dbc.Type)
|
|
assert.Equals(t, ch.Status, tc.dbc.Status)
|
|
assert.Equals(t, ch.Token, tc.dbc.Token)
|
|
assert.Equals(t, ch.Value, tc.dbc.Value)
|
|
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
|
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDB_CreateChallenge(t *testing.T) {
|
|
type test struct {
|
|
db nosql.DB
|
|
ch *acme.Challenge
|
|
err error
|
|
_id *string
|
|
}
|
|
var tests = map[string]func(t *testing.T) test{
|
|
"fail/cmpAndSwap-error": func(t *testing.T) test {
|
|
ch := &acme.Challenge{
|
|
AccountID: "accountID",
|
|
Type: "dns-01",
|
|
Status: acme.StatusPending,
|
|
Token: "token",
|
|
Value: "test.ca.smallstep.com",
|
|
}
|
|
return test{
|
|
db: &db.MockNoSQLDB{
|
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), ch.ID)
|
|
assert.Equals(t, old, nil)
|
|
|
|
dbc := new(dbChallenge)
|
|
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
|
assert.Equals(t, dbc.ID, string(key))
|
|
assert.Equals(t, dbc.AccountID, ch.AccountID)
|
|
assert.Equals(t, dbc.Type, ch.Type)
|
|
assert.Equals(t, dbc.Status, ch.Status)
|
|
assert.Equals(t, dbc.Token, ch.Token)
|
|
assert.Equals(t, dbc.Value, ch.Value)
|
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
|
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
|
return nil, false, errors.New("force")
|
|
},
|
|
},
|
|
ch: ch,
|
|
err: errors.New("error saving acme challenge: force"),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
var (
|
|
id string
|
|
idPtr = &id
|
|
ch = &acme.Challenge{
|
|
AccountID: "accountID",
|
|
Type: "dns-01",
|
|
Status: acme.StatusPending,
|
|
Token: "token",
|
|
Value: "test.ca.smallstep.com",
|
|
}
|
|
)
|
|
|
|
return test{
|
|
ch: ch,
|
|
db: &db.MockNoSQLDB{
|
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
|
*idPtr = string(key)
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), ch.ID)
|
|
assert.Equals(t, old, nil)
|
|
|
|
dbc := new(dbChallenge)
|
|
assert.FatalError(t, json.Unmarshal(nu, dbc))
|
|
assert.Equals(t, dbc.ID, string(key))
|
|
assert.Equals(t, dbc.AccountID, ch.AccountID)
|
|
assert.Equals(t, dbc.Type, ch.Type)
|
|
assert.Equals(t, dbc.Status, ch.Status)
|
|
assert.Equals(t, dbc.Token, ch.Token)
|
|
assert.Equals(t, dbc.Value, ch.Value)
|
|
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
|
|
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
|
|
return nil, true, nil
|
|
},
|
|
},
|
|
_id: idPtr,
|
|
}
|
|
},
|
|
}
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
db := DB{db: tc.db}
|
|
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
|
|
if assert.NotNil(t, tc.err) {
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
}
|
|
} else {
|
|
if assert.Nil(t, tc.err) {
|
|
assert.Equals(t, tc.ch.ID, *tc._id)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDB_GetChallenge(t *testing.T) {
|
|
chID := "chID"
|
|
azID := "azID"
|
|
type test struct {
|
|
db nosql.DB
|
|
err error
|
|
acmeErr *acme.Error
|
|
dbc *dbChallenge
|
|
}
|
|
var tests = map[string]func(t *testing.T) test{
|
|
"fail/db.Get-error": func(t *testing.T) test {
|
|
return test{
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return nil, errors.New("force")
|
|
},
|
|
},
|
|
err: errors.New("error loading acme challenge chID: force"),
|
|
}
|
|
},
|
|
"fail/forward-acme-error": func(t *testing.T) test {
|
|
return test{
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return nil, nosqldb.ErrNotFound
|
|
},
|
|
},
|
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
dbc := &dbChallenge{
|
|
ID: chID,
|
|
AccountID: "accountID",
|
|
Type: "dns-01",
|
|
Status: acme.StatusPending,
|
|
Token: "token",
|
|
Value: "test.ca.smallstep.com",
|
|
CreatedAt: clock.Now(),
|
|
ValidatedAt: "foobar",
|
|
Error: acme.NewErrorISE("force"),
|
|
}
|
|
b, err := json.Marshal(dbc)
|
|
assert.FatalError(t, err)
|
|
return test{
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return b, nil
|
|
},
|
|
},
|
|
dbc: dbc,
|
|
}
|
|
},
|
|
}
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
db := DB{db: tc.db}
|
|
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
|
|
switch k := err.(type) {
|
|
case *acme.Error:
|
|
if assert.NotNil(t, tc.acmeErr) {
|
|
assert.Equals(t, k.Type, tc.acmeErr.Type)
|
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
|
assert.Equals(t, k.Status, tc.acmeErr.Status)
|
|
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
|
|
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
|
|
}
|
|
default:
|
|
if assert.NotNil(t, tc.err) {
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
}
|
|
}
|
|
} else {
|
|
if assert.Nil(t, tc.err) {
|
|
assert.Equals(t, ch.ID, tc.dbc.ID)
|
|
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
|
assert.Equals(t, ch.Type, tc.dbc.Type)
|
|
assert.Equals(t, ch.Status, tc.dbc.Status)
|
|
assert.Equals(t, ch.Token, tc.dbc.Token)
|
|
assert.Equals(t, ch.Value, tc.dbc.Value)
|
|
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
|
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDB_UpdateChallenge(t *testing.T) {
|
|
chID := "chID"
|
|
dbc := &dbChallenge{
|
|
ID: chID,
|
|
AccountID: "accountID",
|
|
Type: "dns-01",
|
|
Status: acme.StatusPending,
|
|
Token: "token",
|
|
Value: "test.ca.smallstep.com",
|
|
CreatedAt: clock.Now(),
|
|
}
|
|
b, err := json.Marshal(dbc)
|
|
assert.FatalError(t, err)
|
|
type test struct {
|
|
db nosql.DB
|
|
ch *acme.Challenge
|
|
err error
|
|
}
|
|
var tests = map[string]func(t *testing.T) test{
|
|
"fail/db.Get-error": func(t *testing.T) test {
|
|
return test{
|
|
ch: &acme.Challenge{
|
|
ID: chID,
|
|
},
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return nil, errors.New("force")
|
|
},
|
|
},
|
|
err: errors.New("error loading acme challenge chID: force"),
|
|
}
|
|
},
|
|
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
|
|
updCh := &acme.Challenge{
|
|
ID: chID,
|
|
Status: acme.StatusValid,
|
|
ValidatedAt: "foobar",
|
|
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
|
}
|
|
return test{
|
|
ch: updCh,
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return b, nil
|
|
},
|
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, old, b)
|
|
|
|
dbOld := new(dbChallenge)
|
|
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
|
assert.Equals(t, dbc, dbOld)
|
|
|
|
dbNew := new(dbChallenge)
|
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
|
assert.Equals(t, dbNew.ID, dbc.ID)
|
|
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
|
|
assert.Equals(t, dbNew.Type, dbc.Type)
|
|
assert.Equals(t, dbNew.Status, updCh.Status)
|
|
assert.Equals(t, dbNew.Token, dbc.Token)
|
|
assert.Equals(t, dbNew.Value, dbc.Value)
|
|
assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error())
|
|
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
|
assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt)
|
|
return nil, false, errors.New("force")
|
|
},
|
|
},
|
|
err: errors.New("error saving acme challenge: force"),
|
|
}
|
|
},
|
|
"ok": func(t *testing.T) test {
|
|
updCh := &acme.Challenge{
|
|
ID: dbc.ID,
|
|
AccountID: dbc.AccountID,
|
|
Type: dbc.Type,
|
|
Token: dbc.Token,
|
|
Value: dbc.Value,
|
|
Status: acme.StatusValid,
|
|
ValidatedAt: "foobar",
|
|
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
|
|
}
|
|
return test{
|
|
ch: updCh,
|
|
db: &db.MockNoSQLDB{
|
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, string(key), chID)
|
|
|
|
return b, nil
|
|
},
|
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
|
assert.Equals(t, bucket, challengeTable)
|
|
assert.Equals(t, old, b)
|
|
|
|
dbOld := new(dbChallenge)
|
|
assert.FatalError(t, json.Unmarshal(old, dbOld))
|
|
assert.Equals(t, dbc, dbOld)
|
|
|
|
dbNew := new(dbChallenge)
|
|
assert.FatalError(t, json.Unmarshal(nu, dbNew))
|
|
assert.Equals(t, dbNew.ID, dbc.ID)
|
|
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
|
|
assert.Equals(t, dbNew.Type, dbc.Type)
|
|
assert.Equals(t, dbNew.Token, dbc.Token)
|
|
assert.Equals(t, dbNew.Value, dbc.Value)
|
|
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
|
|
assert.Equals(t, dbNew.Status, acme.StatusValid)
|
|
assert.Equals(t, dbNew.ValidatedAt, "foobar")
|
|
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
|
return nu, true, nil
|
|
},
|
|
},
|
|
}
|
|
},
|
|
}
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
db := DB{db: tc.db}
|
|
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
|
if assert.NotNil(t, tc.err) {
|
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
|
}
|
|
} else {
|
|
if assert.Nil(t, tc.err) {
|
|
assert.Equals(t, tc.ch.ID, dbc.ID)
|
|
assert.Equals(t, tc.ch.AccountID, dbc.AccountID)
|
|
assert.Equals(t, tc.ch.Type, dbc.Type)
|
|
assert.Equals(t, tc.ch.Token, dbc.Token)
|
|
assert.Equals(t, tc.ch.Value, dbc.Value)
|
|
assert.Equals(t, tc.ch.ValidatedAt, "foobar")
|
|
assert.Equals(t, tc.ch.Status, acme.StatusValid)
|
|
assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|