From 0afea2e95771e2f641d2e5b54ac480c18180d0fd Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Fri, 8 Oct 2021 13:18:23 +0200 Subject: [PATCH] Improve tests for already bound EAB keys --- acme/account.go | 11 ++++++- acme/account_test.go | 65 ++++++++++++++++++++++++++++++++++++++++ acme/api/account.go | 6 +++- acme/api/account_test.go | 28 +++++++++-------- 4 files changed, 96 insertions(+), 14 deletions(-) diff --git a/acme/account.go b/acme/account.go index deaf57c8..14a707e9 100644 --- a/acme/account.go +++ b/acme/account.go @@ -43,6 +43,7 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) { return base64.RawURLEncoding.EncodeToString(kid), nil } +// ExternalAccountKey is an ACME External Account Binding key. type ExternalAccountKey struct { ID string `json:"id"` Provisioner string `json:"provisioner"` @@ -53,12 +54,20 @@ type ExternalAccountKey struct { BoundAt time.Time `json:"boundAt,omitempty"` } +// AlreadyBound returns whether this EAK is already bound to +// an ACME Account or not. func (eak *ExternalAccountKey) AlreadyBound() bool { return !eak.BoundAt.IsZero() } -func (eak *ExternalAccountKey) BindTo(account *Account) { +// BindTo binds the EAK to an Account. +// It returns an error if it's already bound. +func (eak *ExternalAccountKey) BindTo(account *Account) error { + if eak.AlreadyBound() { + return NewError(ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", eak.ID, eak.AccountID, eak.BoundAt) + } eak.AccountID = account.ID eak.BoundAt = time.Now() eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once + return nil } diff --git a/acme/account_test.go b/acme/account_test.go index 5625c3dc..44b815b9 100644 --- a/acme/account_test.go +++ b/acme/account_test.go @@ -4,6 +4,7 @@ import ( "crypto" "encoding/base64" "testing" + "time" "github.com/pkg/errors" "github.com/smallstep/assert" @@ -79,3 +80,67 @@ func TestAccount_IsValid(t *testing.T) { }) } } + +func TestExternalAccountKey_BindTo(t *testing.T) { + boundAt := time.Now() + tests := []struct { + name string + eak *ExternalAccountKey + acct *Account + err *Error + }{ + { + name: "ok", + eak: &ExternalAccountKey{ + ID: "eakID", + Provisioner: "prov", + Reference: "ref", + KeyBytes: []byte{1, 3, 3, 7}, + }, + acct: &Account{ + ID: "accountID", + }, + err: nil, + }, + { + name: "fail/already-bound", + eak: &ExternalAccountKey{ + ID: "eakID", + Provisioner: "prov", + Reference: "ref", + KeyBytes: []byte{1, 3, 3, 7}, + AccountID: "someAccountID", + BoundAt: boundAt, + }, + acct: &Account{ + ID: "accountID", + }, + err: NewError(ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", "eakID", "someAccountID", boundAt), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + eak := tt.eak + acct := tt.acct + err := eak.BindTo(acct) + wantErr := tt.err != nil + gotErr := err != nil + if wantErr != gotErr { + t.Errorf("ExternalAccountKey.BindTo() error = %v, wantErr %v", err, tt.err) + } + if wantErr { + assert.NotNil(t, err) + assert.Type(t, &Error{}, err) + ae, _ := err.(*Error) + assert.Equals(t, ae.Type, tt.err.Type) + assert.Equals(t, ae.Detail, tt.err.Detail) + assert.Equals(t, ae.Identifier, tt.err.Identifier) + assert.Equals(t, ae.Subproblems, tt.err.Subproblems) + } else { + assert.Equals(t, eak.AccountID, acct.ID) + assert.Equals(t, eak.KeyBytes, []byte{}) + assert.NotNil(t, eak.BoundAt) + } + }) + } +} diff --git a/acme/api/account.go b/acme/api/account.go index 8d814d1c..877f5773 100644 --- a/acme/api/account.go +++ b/acme/api/account.go @@ -138,7 +138,11 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { return } if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response - eak.BindTo(acc) + err := eak.BindTo(acc) + if err != nil { + api.WriteError(w, err) + return + } if err := h.db.UpdateExternalAccountKey(ctx, prov.Name, eak); err != nil { api.WriteError(w, acme.WrapErrorISE(err, "error updating external account binding key")) return diff --git a/acme/api/account_test.go b/acme/api/account_test.go index bced48b2..0bedc5d1 100644 --- a/acme/api/account_test.go +++ b/acme/api/account_test.go @@ -1055,6 +1055,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, provisionerContextKey, prov) + createdAt := time.Now() return test{ db: &acme.MockDB{ MockGetExternalAccountKey: func(ctx context.Context, provisionerName string, keyID string) (*acme.ExternalAccountKey, error) { @@ -1063,7 +1064,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { Provisioner: escProvName, Reference: "testeak", KeyBytes: []byte{1, 3, 3, 7}, - CreatedAt: time.Now(), + CreatedAt: createdAt, }, nil }, }, @@ -1072,7 +1073,13 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { Contact: []string{"foo", "bar"}, ExternalAccountBinding: eab, }, - eak: &acme.ExternalAccountKey{}, + eak: &acme.ExternalAccountKey{ + ID: "eakID", + Provisioner: escProvName, + Reference: "testeak", + KeyBytes: []byte{1, 3, 3, 7}, + CreatedAt: createdAt, + }, err: nil, } }, @@ -1299,8 +1306,6 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { wantErr := tc.err != nil gotErr := err != nil if wantErr != gotErr { - // fmt.Println(got) - // fmt.Println(fmt.Sprintf("%#+v", got)) t.Errorf("Handler.validateExternalAccountBinding() error = %v, want %v", err, tc.err) } if wantErr { @@ -1311,20 +1316,19 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) { assert.Equals(t, ae.Detail, tc.err.Detail) assert.Equals(t, ae.Identifier, tc.err.Identifier) assert.Equals(t, ae.Subproblems, tc.err.Subproblems) - - // fmt.Println(fmt.Sprintf("%#+v", ae)) - // fmt.Println(fmt.Sprintf("%#+v", tc.err)) - - //t.Fail() } else { if got == nil { assert.Nil(t, tc.eak) } else { - // TODO: equality check on certain fields? assert.NotNil(t, tc.eak) + assert.Equals(t, got.ID, tc.eak.ID) + assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes) + assert.Equals(t, got.Provisioner, tc.eak.Provisioner) + assert.Equals(t, got.Reference, tc.eak.Reference) + assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt) + assert.Equals(t, got.AccountID, tc.eak.AccountID) + assert.Equals(t, got.BoundAt, tc.eak.BoundAt) } - //assert.Equals(t, tc.eak, got) - //assert.NotNil(t, got) } }) }