[acme db interface] add linker tests

This commit is contained in:
max furman 2021-03-15 10:30:12 -07:00
parent 8d2ebcfd49
commit 074ab7b221
6 changed files with 236 additions and 33 deletions

View file

@ -11,11 +11,11 @@ import (
// Account is a subset of the internal account type containing only those // Account is a subset of the internal account type containing only those
// attributes required for responses in the ACME protocol. // attributes required for responses in the ACME protocol.
type Account struct { type Account struct {
Contact []string `json:"contact,omitempty"` Contact []string `json:"contact,omitempty"`
Status Status `json:"status"` Status Status `json:"status"`
Orders string `json:"orders"` OrdersURL string `json:"orders"`
ID string `json:"-"` ID string `json:"-"`
Key *jose.JSONWebKey `json:"-"` Key *jose.JSONWebKey `json:"-"`
} }
// ToLog enables response logging. // ToLog enables response logging.

View file

@ -420,11 +420,11 @@ func TestHandler_NewAccount(t *testing.T) {
}, },
}, },
acc: &acme.Account{ acc: &acme.Account{
ID: "accountID", ID: "accountID",
Key: jwk, Key: jwk,
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
Orders: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/account/accountID/orders", OrdersURL: "https://test.ca.smallstep.com/acme/test@acme-provisioner.com/account/accountID/orders",
}, },
ctx: ctx, ctx: ctx,
statusCode: 201, statusCode: 201,
@ -496,9 +496,9 @@ func TestHandler_NewAccount(t *testing.T) {
func TestHandler_GetUpdateAccount(t *testing.T) { func TestHandler_GetUpdateAccount(t *testing.T) {
accID := "accountID" accID := "accountID"
acc := acme.Account{ acc := acme.Account{
ID: accID, ID: accID,
Status: "valid", Status: "valid",
Orders: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID), OrdersURL: fmt.Sprintf("https://ca.smallstep.com/acme/account/%s/orders", accID),
} }
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())

View file

@ -160,7 +160,7 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
// LinkAccount sets the ACME links required by an ACME account. // LinkAccount sets the ACME links required by an ACME account.
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) {
acc.Orders = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID) acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, true, acc.ID)
} }
// LinkChallenge sets the ACME links required by an ACME challenge. // LinkChallenge sets the ACME links required by an ACME challenge.

View file

@ -7,9 +7,10 @@ import (
"testing" "testing"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
) )
func TestLinkerGetLink(t *testing.T) { func TestLinker_GetLink(t *testing.T) {
dns := "ca.smallstep.com" dns := "ca.smallstep.com"
prefix := "acme" prefix := "acme"
linker := NewLinker(dns, prefix) linker := NewLinker(dns, prefix)
@ -42,7 +43,7 @@ func TestLinkerGetLink(t *testing.T) {
assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName)) assert.Equals(t, linker.GetLink(ctx, OrderLinkType, false, id), fmt.Sprintf("/%s/order/1234", provName))
} }
func TestLinkerGetLinkExplicit(t *testing.T) { func TestLinker_GetLinkExplicit(t *testing.T) {
dns := "ca.smallstep.com" dns := "ca.smallstep.com"
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prefix := "acme" prefix := "acme"
@ -91,9 +92,211 @@ func TestLinkerGetLinkExplicit(t *testing.T) {
assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, true, baseURL), fmt.Sprintf("%s/acme/%s/key-change", baseURL, provID))
assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID)) assert.Equals(t, linker.GetLinkExplicit(KeyChangeLinkType, provID, false, baseURL), fmt.Sprintf("/%s/key-change", provID))
assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, provID, id, id)) assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, true, baseURL, id, id), fmt.Sprintf("%s/acme/%s/challenge/%s/%s", baseURL, provID, id, id))
assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/challenge/%s/%s", provID, id, id)) assert.Equals(t, linker.GetLinkExplicit(ChallengeLinkType, provID, false, baseURL, id, id), fmt.Sprintf("/%s/challenge/%s/%s", provID, id, id))
assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID)) assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, true, baseURL, id), fmt.Sprintf("%s/acme/%s/certificate/1234", baseURL, provID))
assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID)) assert.Equals(t, linker.GetLinkExplicit(CertificateLinkType, provID, false, baseURL, id), fmt.Sprintf("/%s/certificate/1234", provID))
} }
func TestLinker_LinkOrder(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
oid := "orderID"
certID := "certID"
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
o *acme.Order
validate func(o *acme.Order)
}
var tests = map[string]test{
"no-authz-and-no-cert": {
o: &acme.Order{
ID: oid,
},
validate: func(o *acme.Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{})
assert.Equals(t, o.CertificateURL, "")
},
},
"one-authz-and-cert": {
o: &acme.Order{
ID: oid,
CertificateID: certID,
AuthorizationIDs: []string{"foo"},
},
validate: func(o *acme.Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
})
assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
},
},
"many-authz": {
o: &acme.Order{
ID: oid,
CertificateID: certID,
AuthorizationIDs: []string{"foo", "bar", "zap"},
},
validate: func(o *acme.Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "bar"),
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "zap"),
})
assert.Equals(t, o.CertificateURL, fmt.Sprintf("%s/%s/%s/certificate/%s", baseURL, linkerPrefix, provName, certID))
},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
l.LinkOrder(ctx, tc.o)
tc.validate(tc.o)
})
}
}
func TestLinker_LinkAccount(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
accID := "accountID"
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
a *acme.Account
validate func(o *acme.Account)
}
var tests = map[string]test{
"ok": {
a: &acme.Account{
ID: accID,
},
validate: func(a *acme.Account) {
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
l.LinkAccount(ctx, tc.a)
tc.validate(tc.a)
})
}
}
func TestLinker_LinkChallenge(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
chID := "chID"
azID := "azID"
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
ch *acme.Challenge
validate func(o *acme.Challenge)
}
var tests = map[string]test{
"ok": {
ch: &acme.Challenge{
ID: chID,
AuthzID: azID,
},
validate: func(ch *acme.Challenge) {
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, ch.AuthzID, ch.ID))
},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
l.LinkChallenge(ctx, tc.ch)
tc.validate(tc.ch)
})
}
}
func TestLinker_LinkAuthorization(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
chID0 := "chID-0"
chID1 := "chID-1"
chID2 := "chID-2"
azID := "azID"
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
az *acme.Authorization
validate func(o *acme.Authorization)
}
var tests = map[string]test{
"ok": {
az: &acme.Authorization{
ID: azID,
Challenges: []*acme.Challenge{
{ID: chID0, AuthzID: azID},
{ID: chID1, AuthzID: azID},
{ID: chID2, AuthzID: azID},
},
},
validate: func(az *acme.Authorization) {
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
l.LinkAuthorization(ctx, tc.az)
tc.validate(tc.az)
})
}
}
func TestLinker_LinkOrdersByAccountID(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv()
provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix)
type test struct {
oids []string
}
var tests = map[string]test{
"ok": {
oids: []string{"foo", "bar", "baz"},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
l.LinkOrdersByAccountID(ctx, tc.oids)
assert.Equals(t, tc.oids, []string{
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "foo"),
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "bar"),
fmt.Sprintf("%s/%s/%s/order/%s", baseURL, linkerPrefix, provName, "baz"),
})
})
}
}

View file

@ -320,7 +320,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
if c.acc == nil { if c.acc == nil {
return nil, errors.New("acme client not configured with account") return nil, errors.New("acme client not configured with account")
} }
resp, err := c.post(nil, c.acc.Orders, withKid(c)) resp, err := c.post(nil, c.acc.OrdersURL, withKid(c))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -330,7 +330,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
var orders []string var orders []string
if err := readJSON(resp.Body, &orders); err != nil { if err := readJSON(resp.Body, &orders); err != nil {
return nil, errors.Wrapf(err, "error reading %s", c.acc.Orders) return nil, errors.Wrapf(err, "error reading %s", c.acc.OrdersURL)
} }
return orders, nil return orders, nil

View file

@ -40,9 +40,9 @@ func TestNewACMEClient(t *testing.T) {
KeyChange: srv.URL + "/blorp", KeyChange: srv.URL + "/blorp",
} }
acc := acme.Account{ acc := acme.Account{
Contact: []string{"max", "mariano"}, Contact: []string{"max", "mariano"},
Status: "valid", Status: "valid",
Orders: "orders-url", OrdersURL: "orders-url",
} }
tests := map[string]func(t *testing.T) test{ tests := map[string]func(t *testing.T) test{
"fail/client-option-error": func(t *testing.T) test { "fail/client-option-error": func(t *testing.T) test {
@ -248,9 +248,9 @@ func TestACMEClient_post(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
acc := acme.Account{ acc := acme.Account{
Contact: []string{"max", "mariano"}, Contact: []string{"max", "mariano"},
Status: "valid", Status: "valid",
Orders: "orders-url", OrdersURL: "orders-url",
} }
ac := &ACMEClient{ ac := &ACMEClient{
client: &http.Client{ client: &http.Client{
@ -1121,9 +1121,9 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
Key: jwk, Key: jwk,
kid: "foobar", kid: "foobar",
acc: &acme.Account{ acc: &acme.Account{
Contact: []string{"max", "mariano"}, Contact: []string{"max", "mariano"},
Status: "valid", Status: "valid",
Orders: srv.URL + "/orders-url", OrdersURL: srv.URL + "/orders-url",
}, },
} }
@ -1198,7 +1198,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
assert.Equals(t, hdr.Nonce, expectedNonce) assert.Equals(t, hdr.Nonce, expectedNonce)
jwsURL, ok := hdr.ExtraHeaders["url"].(string) jwsURL, ok := hdr.ExtraHeaders["url"].(string)
assert.Fatal(t, ok) assert.Fatal(t, ok)
assert.Equals(t, jwsURL, ac.acc.Orders) assert.Equals(t, jwsURL, ac.acc.OrdersURL)
assert.Equals(t, hdr.KeyID, ac.kid) assert.Equals(t, hdr.KeyID, ac.kid)
payload, err := jws.Verify(ac.Key.Public()) payload, err := jws.Verify(ac.Key.Public())
@ -1259,9 +1259,9 @@ func TestACMEClient_GetCertificate(t *testing.T) {
Key: jwk, Key: jwk,
kid: "foobar", kid: "foobar",
acc: &acme.Account{ acc: &acme.Account{
Contact: []string{"max", "mariano"}, Contact: []string{"max", "mariano"},
Status: "valid", Status: "valid",
Orders: srv.URL + "/orders-url", OrdersURL: srv.URL + "/orders-url",
}, },
} }