[acme db interface] unit tests for challenge nosql db

This commit is contained in:
max furman 2021-03-18 23:08:13 -07:00
parent 4b1dda5bb6
commit 206909b12e
9 changed files with 789 additions and 97 deletions

View file

@ -10,7 +10,7 @@ import (
type Authorization struct { type Authorization struct {
Identifier Identifier `json:"identifier"` Identifier Identifier `json:"identifier"`
Status Status `json:"status"` Status Status `json:"status"`
Expires time.Time `json:"expires"` ExpiresAt time.Time `json:"expires"`
Challenges []*Challenge `json:"challenges"` Challenges []*Challenge `json:"challenges"`
Wildcard bool `json:"wildcard"` Wildcard bool `json:"wildcard"`
ID string `json:"-"` ID string `json:"-"`
@ -39,7 +39,7 @@ func (az *Authorization) UpdateStatus(ctx context.Context, db DB) error {
return nil return nil
case StatusPending: case StatusPending:
// check expiry // check expiry
if now.After(az.Expires) { if now.After(az.ExpiresAt) {
az.Status = StatusInvalid az.Status = StatusInvalid
break break
} }

View file

@ -25,7 +25,7 @@ type Challenge struct {
Type string `json:"type"` Type string `json:"type"`
Status Status `json:"status"` Status Status `json:"status"`
Token string `json:"token"` Token string `json:"token"`
Validated string `json:"validated,omitempty"` ValidatedAt string `json:"validated,omitempty"`
URL string `json:"url"` URL string `json:"url"`
Error *Error `json:"error,omitempty"` Error *Error `json:"error,omitempty"`
ID string `json:"-"` ID string `json:"-"`
@ -97,7 +97,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
// Update and store the challenge. // Update and store the challenge.
ch.Status = StatusValid ch.Status = StatusValid
ch.Error = nil ch.Error = nil
ch.Validated = clock.Now().Format(time.RFC3339) ch.ValidatedAt = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil { if err = db.UpdateChallenge(ctx, ch); err != nil {
return WrapErrorISE(err, "error updating challenge") return WrapErrorISE(err, "error updating challenge")
@ -175,7 +175,7 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
ch.Status = StatusValid ch.Status = StatusValid
ch.Error = nil ch.Error = nil
ch.Validated = clock.Now().Format(time.RFC3339) ch.ValidatedAt = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil { if err = db.UpdateChallenge(ctx, ch); err != nil {
return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge") return WrapErrorISE(err, "tlsalpn01ValidateChallenge - error updating challenge")
@ -231,7 +231,7 @@ func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebK
// Update and store the challenge. // Update and store the challenge.
ch.Status = StatusValid ch.Status = StatusValid
ch.Error = nil ch.Error = nil
ch.Validated = clock.Now().UTC().Format(time.RFC3339) ch.ValidatedAt = clock.Now().Format(time.RFC3339)
if err = db.UpdateChallenge(ctx, ch); err != nil { if err = db.UpdateChallenge(ctx, ch); err != nil {
return WrapErrorISE(err, "error updating challenge") return WrapErrorISE(err, "error updating challenge")

View file

@ -18,10 +18,10 @@ type dbAuthz struct {
AccountID string `json:"accountID"` AccountID string `json:"accountID"`
Identifier acme.Identifier `json:"identifier"` Identifier acme.Identifier `json:"identifier"`
Status acme.Status `json:"status"` Status acme.Status `json:"status"`
Expires time.Time `json:"expires"` ExpiresAt time.Time `json:"expiresAt"`
Challenges []string `json:"challenges"` Challenges []string `json:"challenges"`
Wildcard bool `json:"wildcard"` Wildcard bool `json:"wildcard"`
Created time.Time `json:"created"` CreatedAt time.Time `json:"createdAt"`
Error *acme.Error `json:"error"` Error *acme.Error `json:"error"`
} }
@ -66,7 +66,7 @@ func (db *DB) GetAuthorization(ctx context.Context, id string) (*acme.Authorizat
Status: dbaz.Status, Status: dbaz.Status,
Challenges: chs, Challenges: chs,
Wildcard: dbaz.Wildcard, Wildcard: dbaz.Wildcard,
Expires: dbaz.Expires, ExpiresAt: dbaz.ExpiresAt,
ID: dbaz.ID, ID: dbaz.ID,
}, nil }, nil
} }
@ -90,8 +90,8 @@ func (db *DB) CreateAuthorization(ctx context.Context, az *acme.Authorization) e
ID: az.ID, ID: az.ID,
AccountID: az.AccountID, AccountID: az.AccountID,
Status: acme.StatusPending, Status: acme.StatusPending,
Created: now, CreatedAt: now,
Expires: now.Add(defaultExpiryDuration), ExpiresAt: now.Add(defaultExpiryDuration),
Identifier: az.Identifier, Identifier: az.Identifier,
Challenges: chIDs, Challenges: chIDs,
Wildcard: az.Wildcard, Wildcard: az.Wildcard,

View file

@ -14,7 +14,7 @@ import (
type dbCert struct { type dbCert struct {
ID string `json:"id"` ID string `json:"id"`
Created time.Time `json:"created"` CreatedAt time.Time `json:"createdAt"`
AccountID string `json:"accountID"` AccountID string `json:"accountID"`
OrderID string `json:"orderID"` OrderID string `json:"orderID"`
Leaf []byte `json:"leaf"` Leaf []byte `json:"leaf"`
@ -47,7 +47,7 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err
OrderID: cert.OrderID, OrderID: cert.OrderID,
Leaf: leaf, Leaf: leaf,
Intermediates: intermediates, Intermediates: intermediates,
Created: time.Now().UTC(), CreatedAt: time.Now().UTC(),
} }
return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable) return db.save(ctx, cert.ID, dbch, nil, "certificate", certTable)
} }
@ -57,7 +57,7 @@ func (db *DB) CreateCertificate(ctx context.Context, cert *acme.Certificate) err
func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) { func (db *DB) GetCertificate(ctx context.Context, id string) (*acme.Certificate, error) {
b, err := db.db.Get(certTable, []byte(id)) b, err := db.db.Get(certTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, errors.Wrapf(err, "certificate %s not found", id) return nil, acme.NewError(acme.ErrorMalformedType, "certificate %s not found", id)
} else if err != nil { } else if err != nil {
return nil, errors.Wrapf(err, "error loading certificate %s", id) return nil, errors.Wrapf(err, "error loading certificate %s", id)
} }

View file

@ -34,7 +34,7 @@ func TestDB_CreateCertificate(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test { "fail/cmpAndSwap-error": func(t *testing.T) test {
cert := &acme.Certificate{ cert := &acme.Certificate{
AccountID: "accounttID", AccountID: "accountID",
OrderID: "orderID", OrderID: "orderID",
Leaf: leaf, Leaf: leaf,
Intermediates: []*x509.Certificate{inter, root}, Intermediates: []*x509.Certificate{inter, root},
@ -48,12 +48,11 @@ func TestDB_CreateCertificate(t *testing.T) {
dbc := new(dbCert) dbc := new(dbCert)
assert.FatalError(t, json.Unmarshal(nu, dbc)) assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.FatalError(t, err)
assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.ID, cert.ID) assert.Equals(t, dbc.ID, cert.ID)
assert.Equals(t, dbc.AccountID, cert.AccountID) assert.Equals(t, dbc.AccountID, cert.AccountID)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.Created)) assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.Created)) assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, false, errors.New("force") return nil, false, errors.New("force")
}, },
}, },
@ -63,7 +62,7 @@ func TestDB_CreateCertificate(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
cert := &acme.Certificate{ cert := &acme.Certificate{
AccountID: "accounttID", AccountID: "accountID",
OrderID: "orderID", OrderID: "orderID",
Leaf: leaf, Leaf: leaf,
Intermediates: []*x509.Certificate{inter, root}, Intermediates: []*x509.Certificate{inter, root},
@ -83,12 +82,11 @@ func TestDB_CreateCertificate(t *testing.T) {
dbc := new(dbCert) dbc := new(dbCert)
assert.FatalError(t, json.Unmarshal(nu, dbc)) assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.FatalError(t, err)
assert.Equals(t, dbc.ID, string(key)) assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.ID, cert.ID) assert.Equals(t, dbc.ID, cert.ID)
assert.Equals(t, dbc.AccountID, cert.AccountID) assert.Equals(t, dbc.AccountID, cert.AccountID)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.Created)) assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.Created)) assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, true, nil return nil, true, nil
}, },
}, },
@ -126,6 +124,7 @@ func TestDB_GetCertificate(t *testing.T) {
type test struct { type test struct {
db nosql.DB db nosql.DB
err error err error
acmeErr *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test { "fail/not-found": func(t *testing.T) test {
@ -138,7 +137,7 @@ func TestDB_GetCertificate(t *testing.T) {
return nil, nosqldb.ErrNotFound return nil, nosqldb.ErrNotFound
}, },
}, },
err: errors.New("certificate certID not found"), acmeErr: acme.NewError(acme.ErrorMalformedType, "certificate certID not found"),
} }
}, },
"fail/db.Get-error": func(t *testing.T) test { "fail/db.Get-error": func(t *testing.T) test {
@ -182,7 +181,7 @@ func TestDB_GetCertificate(t *testing.T) {
Type: "Public Key", Type: "Public Key",
Bytes: leaf.Raw, Bytes: leaf.Raw,
}), }),
Created: clock.Now(), CreatedAt: clock.Now(),
} }
b, err := json.Marshal(cert) b, err := json.Marshal(cert)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -215,7 +214,7 @@ func TestDB_GetCertificate(t *testing.T) {
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: root.Raw, Bytes: root.Raw,
})...), })...),
Created: clock.Now(), CreatedAt: clock.Now(),
} }
b, err := json.Marshal(cert) b, err := json.Marshal(cert)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -232,9 +231,20 @@ func TestDB_GetCertificate(t *testing.T) {
db := DB{db: tc.db} db := DB{db: tc.db}
cert, err := db.GetCertificate(context.Background(), certID) cert, err := db.GetCertificate(context.Background(), certID)
if err != nil { if err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
}
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, cert.ID, certID) assert.Equals(t, cert.ID, certID)

View file

@ -10,7 +10,6 @@ import (
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
) )
// dbChallenge is the base Challenge type that others build from.
type dbChallenge struct { type dbChallenge struct {
ID string `json:"id"` ID string `json:"id"`
AccountID string `json:"accountID"` AccountID string `json:"accountID"`
@ -19,8 +18,8 @@ type dbChallenge struct {
Status acme.Status `json:"status"` Status acme.Status `json:"status"`
Token string `json:"token"` Token string `json:"token"`
Value string `json:"value"` Value string `json:"value"`
Validated string `json:"validated"` ValidatedAt string `json:"validatedAt"`
Created time.Time `json:"created"` CreatedAt time.Time `json:"createdAt"`
Error *acme.Error `json:"error"` Error *acme.Error `json:"error"`
} }
@ -32,9 +31,9 @@ func (dbc *dbChallenge) clone() *dbChallenge {
func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) { func (db *DB) getDBChallenge(ctx context.Context, id string) (*dbChallenge, error) {
data, err := db.db.Get(challengeTable, []byte(id)) data, err := db.db.Get(challengeTable, []byte(id))
if nosql.IsErrNotFound(err) { if nosql.IsErrNotFound(err) {
return nil, errors.Wrapf(err, "challenge %s not found", id) return nil, acme.NewError(acme.ErrorMalformedType, "challenge %s not found", id)
} else if err != nil { } else if err != nil {
return nil, errors.Wrapf(err, "error loading challenge %s", id) return nil, errors.Wrapf(err, "error loading acme challenge %s", id)
} }
dbch := new(dbChallenge) dbch := new(dbChallenge)
@ -60,7 +59,7 @@ func (db *DB) CreateChallenge(ctx context.Context, ch *acme.Challenge) error {
Value: ch.Value, Value: ch.Value,
Status: acme.StatusPending, Status: acme.StatusPending,
Token: ch.Token, Token: ch.Token,
Created: clock.Now(), CreatedAt: clock.Now(),
Type: ch.Type, Type: ch.Type,
} }
@ -76,22 +75,21 @@ func (db *DB) GetChallenge(ctx context.Context, id, authzID string) (*acme.Chall
} }
ch := &acme.Challenge{ ch := &acme.Challenge{
ID: dbch.ID,
AccountID: dbch.AccountID,
AuthzID: dbch.AuthzID,
Type: dbch.Type, Type: dbch.Type,
Value: dbch.Value,
Status: dbch.Status, Status: dbch.Status,
Token: dbch.Token, Token: dbch.Token,
ID: dbch.ID,
AuthzID: dbch.AuthzID,
Error: dbch.Error, Error: dbch.Error,
Validated: dbch.Validated, ValidatedAt: dbch.ValidatedAt,
} }
return ch, nil return ch, nil
} }
// UpdateChallenge updates an ACME challenge type in the database. // UpdateChallenge updates an ACME challenge type in the database.
func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error { func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
if len(ch.ID) == 0 {
return errors.New("id cannot be empty")
}
old, err := db.getDBChallenge(ctx, ch.ID) old, err := db.getDBChallenge(ctx, ch.ID)
if err != nil { if err != nil {
return err return err
@ -99,10 +97,10 @@ func (db *DB) UpdateChallenge(ctx context.Context, ch *acme.Challenge) error {
nu := old.clone() nu := old.clone()
// These should be the only values chaning in an Update request. // These should be the only values changing in an Update request.
nu.Status = ch.Status nu.Status = ch.Status
nu.Error = ch.Error nu.Error = ch.Error
nu.Validated = ch.Validated nu.ValidatedAt = ch.ValidatedAt
return db.save(ctx, old.ID, nu, old, "challenge", challengeTable) return db.save(ctx, old.ID, nu, old, "challenge", challengeTable)
} }

View file

@ -0,0 +1,477 @@
package nosql
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database"
)
func TestDB_getDBChallenge(t *testing.T) {
chID := "chID"
type test struct {
db nosql.DB
err error
acmeErr *acme.Error
dbc *dbChallenge
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, nosqldb.ErrNotFound
},
},
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, errors.New("force")
},
},
err: errors.New("error loading acme challenge chID: force"),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return []byte("foo"), nil
},
},
err: errors.New("error unmarshaling dbChallenge"),
}
},
"ok": func(t *testing.T) test {
dbc := &dbChallenge{
ID: chID,
AccountID: "accountID",
AuthzID: "authzID",
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
CreatedAt: clock.Now(),
ValidatedAt: "foobar",
Error: acme.NewErrorISE("force"),
}
b, err := json.Marshal(dbc)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return b, nil
},
},
dbc: dbc,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err, tc.acmeErr.Err)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, ch.ID, tc.dbc.ID)
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
assert.Equals(t, ch.AuthzID, tc.dbc.AuthzID)
assert.Equals(t, ch.Type, tc.dbc.Type)
assert.Equals(t, ch.Status, tc.dbc.Status)
assert.Equals(t, ch.Token, tc.dbc.Token)
assert.Equals(t, ch.Value, tc.dbc.Value)
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
}
}
})
}
}
func TestDB_CreateChallenge(t *testing.T) {
type test struct {
db nosql.DB
ch *acme.Challenge
err error
_id *string
}
var tests = map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test {
ch := &acme.Challenge{
AccountID: "accountID",
AuthzID: "authzID",
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
}
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), ch.ID)
assert.Equals(t, old, nil)
dbc := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.AccountID, ch.AccountID)
assert.Equals(t, dbc.AuthzID, ch.AuthzID)
assert.Equals(t, dbc.Type, ch.Type)
assert.Equals(t, dbc.Status, ch.Status)
assert.Equals(t, dbc.Token, ch.Token)
assert.Equals(t, dbc.Value, ch.Value)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, false, errors.New("force")
},
},
ch: ch,
err: errors.New("error saving acme challenge: force"),
}
},
"ok": func(t *testing.T) test {
var (
id string
idPtr = &id
ch = &acme.Challenge{
AccountID: "accountID",
AuthzID: "authzID",
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
}
)
return test{
ch: ch,
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
*idPtr = string(key)
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), ch.ID)
assert.Equals(t, old, nil)
dbc := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(nu, dbc))
assert.Equals(t, dbc.ID, string(key))
assert.Equals(t, dbc.AccountID, ch.AccountID)
assert.Equals(t, dbc.AuthzID, ch.AuthzID)
assert.Equals(t, dbc.Type, ch.Type)
assert.Equals(t, dbc.Status, ch.Status)
assert.Equals(t, dbc.Token, ch.Token)
assert.Equals(t, dbc.Value, ch.Value)
assert.True(t, clock.Now().Add(-time.Minute).Before(dbc.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbc.CreatedAt))
return nil, true, nil
},
},
_id: idPtr,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.ch.ID, *tc._id)
}
}
})
}
}
func TestDB_GetChallenge(t *testing.T) {
chID := "chID"
azID := "azID"
type test struct {
db nosql.DB
err error
acmeErr *acme.Error
dbc *dbChallenge
}
var tests = map[string]func(t *testing.T) test{
"fail/db.Get-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, errors.New("force")
},
},
err: errors.New("error loading acme challenge chID: force"),
}
},
"fail/forward-acme-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, nosqldb.ErrNotFound
},
},
acmeErr: acme.NewError(acme.ErrorMalformedType, "challenge chID not found"),
}
},
"ok": func(t *testing.T) test {
dbc := &dbChallenge{
ID: chID,
AccountID: "accountID",
AuthzID: azID,
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
CreatedAt: clock.Now(),
ValidatedAt: "foobar",
Error: acme.NewErrorISE("force"),
}
b, err := json.Marshal(dbc)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return b, nil
},
},
dbc: dbc,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
switch k := err.(type) {
case *acme.Error:
if assert.NotNil(t, tc.acmeErr) {
assert.Equals(t, k.Type, tc.acmeErr.Type)
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
assert.Equals(t, k.Status, tc.acmeErr.Status)
assert.Equals(t, k.Err.Error(), tc.acmeErr.Err.Error())
assert.Equals(t, k.Detail, tc.acmeErr.Detail)
}
default:
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, ch.ID, tc.dbc.ID)
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
assert.Equals(t, ch.AuthzID, tc.dbc.AuthzID)
assert.Equals(t, ch.Type, tc.dbc.Type)
assert.Equals(t, ch.Status, tc.dbc.Status)
assert.Equals(t, ch.Token, tc.dbc.Token)
assert.Equals(t, ch.Value, tc.dbc.Value)
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
}
}
})
}
}
func TestDB_UpdateChallenge(t *testing.T) {
chID := "chID"
dbc := &dbChallenge{
ID: chID,
AccountID: "accountID",
AuthzID: "azID",
Type: "dns-01",
Status: acme.StatusPending,
Token: "token",
Value: "test.ca.smallstep.com",
CreatedAt: clock.Now(),
}
b, err := json.Marshal(dbc)
assert.FatalError(t, err)
type test struct {
db nosql.DB
ch *acme.Challenge
err error
}
var tests = map[string]func(t *testing.T) test{
"fail/db.Get-error": func(t *testing.T) test {
return test{
ch: &acme.Challenge{
ID: chID,
},
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return nil, errors.New("force")
},
},
err: errors.New("error loading acme challenge chID: force"),
}
},
"fail/db.CmpAndSwap-error": func(t *testing.T) test {
updCh := &acme.Challenge{
ID: chID,
Status: acme.StatusValid,
ValidatedAt: "foobar",
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
}
return test{
ch: updCh,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, old, b)
dbOld := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(old, dbOld))
assert.Equals(t, dbc, dbOld)
dbNew := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(nu, dbNew))
assert.Equals(t, dbNew.ID, dbc.ID)
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
assert.Equals(t, dbNew.AuthzID, dbc.AuthzID)
assert.Equals(t, dbNew.Type, dbc.Type)
assert.Equals(t, dbNew.Status, updCh.Status)
assert.Equals(t, dbNew.Token, dbc.Token)
assert.Equals(t, dbNew.Value, dbc.Value)
assert.Equals(t, dbNew.Error.Error(), updCh.Error.Error())
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
assert.Equals(t, dbNew.ValidatedAt, updCh.ValidatedAt)
return nil, false, errors.New("force")
},
},
err: errors.New("error saving acme challenge: force"),
}
},
"ok": func(t *testing.T) test {
updCh := &acme.Challenge{
ID: dbc.ID,
AccountID: dbc.AccountID,
AuthzID: dbc.AuthzID,
Type: dbc.Type,
Token: dbc.Token,
Value: dbc.Value,
Status: acme.StatusValid,
ValidatedAt: "foobar",
Error: acme.NewError(acme.ErrorMalformedType, "malformed"),
}
return test{
ch: updCh,
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, string(key), chID)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, challengeTable)
assert.Equals(t, old, b)
dbOld := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(old, dbOld))
assert.Equals(t, dbc, dbOld)
dbNew := new(dbChallenge)
assert.FatalError(t, json.Unmarshal(nu, dbNew))
assert.Equals(t, dbNew.ID, dbc.ID)
assert.Equals(t, dbNew.AccountID, dbc.AccountID)
assert.Equals(t, dbNew.AuthzID, dbc.AuthzID)
assert.Equals(t, dbNew.Type, dbc.Type)
assert.Equals(t, dbNew.Token, dbc.Token)
assert.Equals(t, dbNew.Value, dbc.Value)
assert.Equals(t, dbNew.CreatedAt, dbc.CreatedAt)
assert.Equals(t, dbNew.Status, acme.StatusValid)
assert.Equals(t, dbNew.ValidatedAt, "foobar")
assert.Equals(t, dbNew.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
return nu, true, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, tc.ch.ID, dbc.ID)
assert.Equals(t, tc.ch.AccountID, dbc.AccountID)
assert.Equals(t, tc.ch.AuthzID, dbc.AuthzID)
assert.Equals(t, tc.ch.Type, dbc.Type)
assert.Equals(t, tc.ch.Token, dbc.Token)
assert.Equals(t, tc.ch.Value, dbc.Value)
assert.Equals(t, tc.ch.ValidatedAt, "foobar")
assert.Equals(t, tc.ch.Status, acme.StatusValid)
assert.Equals(t, tc.ch.Error.Error(), acme.NewError(acme.ErrorMalformedType, "malformed").Error())
}
}
})
}
}

View file

@ -8,14 +8,19 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
nosqlDB "github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database"
) )
// dbNonce contains nonce metadata used in the ACME protocol. // dbNonce contains nonce metadata used in the ACME protocol.
type dbNonce struct { type dbNonce struct {
ID string ID string
Created time.Time CreatedAt time.Time
DeletedAt time.Time
}
func (dbn *dbNonce) clone() *dbNonce {
u := *dbn
return &u
} }
// CreateNonce creates, stores, and returns an ACME replay-nonce. // CreateNonce creates, stores, and returns an ACME replay-nonce.
@ -29,13 +34,9 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
id := base64.RawURLEncoding.EncodeToString([]byte(_id)) id := base64.RawURLEncoding.EncodeToString([]byte(_id))
n := &dbNonce{ n := &dbNonce{
ID: id, ID: id,
Created: clock.Now(), CreatedAt: clock.Now(),
} }
b, err := json.Marshal(n) if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
if err != nil {
return "", errors.Wrap(err, "error marshaling nonce")
}
if err = db.save(ctx, id, b, nil, "nonce", nonceTable); err != nil {
return "", err return "", err
} }
return acme.Nonce(id), nil return acme.Nonce(id), nil
@ -44,27 +45,24 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
// DeleteNonce verifies that the nonce is valid (by checking if it exists), // DeleteNonce verifies that the nonce is valid (by checking if it exists),
// and if so, consumes the nonce resource by deleting it from the database. // and if so, consumes the nonce resource by deleting it from the database.
func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error { func (db *DB) DeleteNonce(ctx context.Context, nonce acme.Nonce) error {
err := db.db.Update(&database.Tx{ id := string(nonce)
Operations: []*database.TxEntry{ b, err := db.db.Get(nonceTable, []byte(nonce))
{ if nosql.IsErrNotFound(err) {
Bucket: nonceTable, return errors.Wrapf(err, "nonce %s not found", id)
Key: []byte(nonce), } else if err != nil {
Cmd: database.Get, return errors.Wrapf(err, "error loading nonce %s", id)
}, }
{
Bucket: nonceTable,
Key: []byte(nonce),
Cmd: database.Delete,
},
},
})
switch { dbn := new(dbNonce)
case nosqlDB.IsErrNotFound(err): if err := json.Unmarshal(b, dbn); err != nil {
return errors.New("not found") return errors.Wrapf(err, "error unmarshaling nonce %s", string(nonce))
case err != nil:
return errors.Wrapf(err, "error deleting nonce %s", nonce)
default:
return nil
} }
if !dbn.DeletedAt.IsZero() {
return acme.NewError(acme.ErrorBadNonceType, "nonce %s already deleted", id)
}
nu := dbn.clone()
nu.DeletedAt = clock.Now()
return db.save(ctx, id, nu, dbn, "nonce", nonceTable)
} }

209
acme/db/nosql/nonce_test.go Normal file
View file

@ -0,0 +1,209 @@
package nosql
import (
"context"
"encoding/json"
"testing"
"time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/db"
"github.com/smallstep/nosql"
nosqldb "github.com/smallstep/nosql/database"
)
func TestDB_CreateNonce(t *testing.T) {
type test struct {
db nosql.DB
nonce *acme.Nonce
err error
_id *string
}
var tests = map[string]func(t *testing.T) test{
"fail/cmpAndSwap-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, nil)
dbn := new(dbNonce)
assert.FatalError(t, json.Unmarshal(nu, dbn))
assert.Equals(t, dbn.ID, string(key))
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
return nil, false, errors.New("force")
},
},
err: errors.New("error saving acme nonce: force"),
}
},
"ok": func(t *testing.T) test {
var (
id string
idPtr = &id
)
return test{
db: &db.MockNoSQLDB{
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
*idPtr = string(key)
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, nil)
dbn := new(dbNonce)
assert.FatalError(t, json.Unmarshal(nu, dbn))
assert.Equals(t, dbn.ID, string(key))
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.CreatedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbn.CreatedAt))
return nil, true, nil
},
},
_id: idPtr,
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if n, err := db.CreateNonce(context.Background()); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
if assert.Nil(t, tc.err) {
assert.Equals(t, string(n), *tc._id)
}
}
})
}
}
func TestDB_DeleteNonce(t *testing.T) {
nonceID := "nonceID"
type test struct {
db nosql.DB
err error
}
var tests = map[string]func(t *testing.T) test{
"fail/not-found": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
return nil, nosqldb.ErrNotFound
},
},
err: errors.New("nonce nonceID not found"),
}
},
"fail/db.Get-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
return nil, errors.Errorf("force")
},
},
err: errors.New("error loading nonce nonceID: force"),
}
},
"fail/unmarshal-error": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
a := []string{"foo", "bar", "baz"}
b, err := json.Marshal(a)
assert.FatalError(t, err)
return b, nil
},
},
err: errors.New("error unmarshaling nonce nonceID"),
}
},
"fail/already-used": func(t *testing.T) test {
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
nonce := dbNonce{
ID: nonceID,
CreatedAt: clock.Now().Add(-5 * time.Minute),
DeletedAt: clock.Now(),
}
b, err := json.Marshal(nonce)
assert.FatalError(t, err)
return b, nil
},
},
err: acme.NewError(acme.ErrorBadNonceType, "nonce already deleted"),
}
},
"ok": func(t *testing.T) test {
nonce := dbNonce{
ID: nonceID,
CreatedAt: clock.Now().Add(-5 * time.Minute),
}
b, err := json.Marshal(nonce)
assert.FatalError(t, err)
return test{
db: &db.MockNoSQLDB{
MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, string(key), nonceID)
return b, nil
},
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, nonceTable)
assert.Equals(t, old, b)
dbo := new(dbNonce)
assert.FatalError(t, json.Unmarshal(old, dbo))
assert.Equals(t, dbo.ID, string(key))
assert.True(t, clock.Now().Add(-6*time.Minute).Before(dbo.CreatedAt))
assert.True(t, clock.Now().Add(-4*time.Minute).After(dbo.CreatedAt))
assert.True(t, dbo.DeletedAt.IsZero())
dbn := new(dbNonce)
assert.FatalError(t, json.Unmarshal(nu, dbn))
assert.Equals(t, dbn.ID, string(key))
assert.True(t, clock.Now().Add(-6*time.Minute).Before(dbn.CreatedAt))
assert.True(t, clock.Now().Add(-4*time.Minute).After(dbn.CreatedAt))
assert.True(t, clock.Now().Add(-time.Minute).Before(dbn.DeletedAt))
assert.True(t, clock.Now().Add(time.Minute).After(dbn.DeletedAt))
return nil, true, nil
},
},
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
db := DB{db: tc.db}
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error())
}
} else {
assert.Nil(t, tc.err)
}
})
}
}