Fix unit tests - work in progress

This commit is contained in:
Mariano Cano 2022-04-27 19:08:16 -07:00
parent 42435ace64
commit bb8d85a201
6 changed files with 65 additions and 64 deletions

View file

@ -315,11 +315,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetOrdersByAccountID(w, req)
GetOrdersByAccountID(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -759,11 +759,11 @@ func TestHandler_NewAccount(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.NewAccount(w, req)
NewAccount(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -959,11 +959,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetOrUpdateAccount(w, req)
GetOrUpdateAccount(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)

View file

@ -762,10 +762,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
db: tc.db,
}
got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar)
// h := &Handler{
// db: tc.db,
// }
got, err := validateExternalAccountBinding(tc.ctx, tc.nar)
wantErr := tc.err != nil
gotErr := err != nil
if wantErr != gotErr {

View file

@ -38,10 +38,10 @@ func TestHandler_GetNonce(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &Handler{}
// h := &Handler{}
w := httptest.NewRecorder()
req.Method = tt.name
h.GetNonce(w, req)
GetNonce(w, req)
res := w.Result()
if res.StatusCode != tt.statusCode {
@ -53,6 +53,7 @@ func TestHandler_GetNonce(t *testing.T) {
func TestHandler_GetDirectory(t *testing.T) {
linker := NewLinker("ca.smallstep.com", "acme")
_ = linker
type test struct {
ctx context.Context
statusCode int
@ -130,11 +131,11 @@ func TestHandler_GetDirectory(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: linker}
// h := &Handler{linker: linker}
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetDirectory(w, req)
GetDirectory(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -304,11 +305,11 @@ func TestHandler_GetAuthorization(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetAuthorization(w, req)
GetAuthorization(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -447,11 +448,11 @@ func TestHandler_GetCertificate(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
// h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetCertificate(w, req)
GetCertificate(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -703,11 +704,11 @@ func TestHandler_GetChallenge(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
// h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetChallenge(w, req)
GetChallenge(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)

View file

@ -70,7 +70,7 @@ func Test_baseURLFromRequest(t *testing.T) {
if tc.requestPreparer != nil {
tc.requestPreparer(request)
}
result := baseURLFromRequest(request)
result := getBaseURLFromRequest(request)
if result == nil || tc.expectedResult == nil {
assert.Equals(t, result, tc.expectedResult)
} else if result.String() != tc.expectedResult.String() {
@ -81,7 +81,7 @@ func Test_baseURLFromRequest(t *testing.T) {
}
func TestHandler_baseURLFromRequest(t *testing.T) {
h := &Handler{}
// h := &Handler{}
req := httptest.NewRequest("GET", "/foo", nil)
req.Host = "test.ca.smallstep.com:8080"
w := httptest.NewRecorder()
@ -94,7 +94,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
}
}
h.baseURLFromRequest(next)(w, req)
baseURLFromRequest(next)(w, req)
req = httptest.NewRequest("GET", "/foo", nil)
req.Host = ""
@ -103,7 +103,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
assert.Equals(t, baseURLFromContext(r.Context()), nil)
}
h.baseURLFromRequest(next)(w, req)
baseURLFromRequest(next)(w, req)
}
func TestHandler_addNonce(t *testing.T) {
@ -139,10 +139,10 @@ func TestHandler_addNonce(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
// h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", u, nil)
w := httptest.NewRecorder()
h.addNonce(testNext)(w, req)
addNonce(testNext)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -195,11 +195,11 @@ func TestHandler_addDirLink(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: tc.linker}
// h := &Handler{linker: tc.linker}
req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.addDirLink(testNext)(w, req)
addDirLink(testNext)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -242,7 +242,7 @@ func TestHandler_verifyContentType(t *testing.T) {
"fail/provisioner-not-set": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
// linker: NewLinker("dns", "acme"),
},
url: u,
ctx: context.Background(),
@ -254,7 +254,7 @@ func TestHandler_verifyContentType(t *testing.T) {
"fail/general-bad-content-type": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
// linker: NewLinker("dns", "acme"),
},
url: u,
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
@ -266,7 +266,7 @@ func TestHandler_verifyContentType(t *testing.T) {
"fail/certificate-bad-content-type": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo",
@ -277,7 +277,7 @@ func TestHandler_verifyContentType(t *testing.T) {
"ok": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json",
@ -287,7 +287,7 @@ func TestHandler_verifyContentType(t *testing.T) {
"ok/certificate/pkix-cert": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkix-cert",
@ -297,7 +297,7 @@ func TestHandler_verifyContentType(t *testing.T) {
"ok/certificate/jose+json": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json",
@ -307,7 +307,7 @@ func TestHandler_verifyContentType(t *testing.T) {
"ok/certificate/pkcs7-mime": func(t *testing.T) test {
return test{
h: Handler{
linker: NewLinker("dns", "acme"),
// linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkcs7-mime",
@ -326,7 +326,7 @@ func TestHandler_verifyContentType(t *testing.T) {
req = req.WithContext(tc.ctx)
req.Header.Add("Content-Type", tc.contentType)
w := httptest.NewRecorder()
tc.h.verifyContentType(testNext)(w, req)
verifyContentType(testNext)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -390,11 +390,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{}
// h := &Handler{}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.isPostAsGet(testNext)(w, req)
isPostAsGet(testNext)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -481,10 +481,10 @@ func TestHandler_parseJWS(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{}
// h := &Handler{}
req := httptest.NewRequest("GET", u, tc.body)
w := httptest.NewRecorder()
h.parseJWS(tc.next)(w, req)
parseJWS(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -679,11 +679,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{}
// h := &Handler{}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.verifyAndExtractJWSPayload(tc.next)(w, req)
verifyAndExtractJWSPayload(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -881,11 +881,11 @@ func TestHandler_lookupJWK(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker}
// h := &Handler{db: tc.db, linker: tc.linker}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.lookupJWK(tc.next)(w, req)
lookupJWK(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1077,11 +1077,11 @@ func TestHandler_extractJWK(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
// h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.extractJWK(tc.next)(w, req)
extractJWK(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1444,11 +1444,11 @@ func TestHandler_validateJWS(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
// h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.validateJWS(tc.next)(w, req)
validateJWS(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1628,11 +1628,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker}
// h := &Handler{db: tc.db, linker: tc.linker}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.extractOrLookupJWK(tc.next)(w, req)
extractOrLookupJWK(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1717,11 +1717,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
// h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.checkPrerequisites(tc.next)(w, req)
checkPrerequisites(tc.next)(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)

View file

@ -421,11 +421,11 @@ func TestHandler_GetOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
// h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.GetOrder(w, req)
GetOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -636,8 +636,8 @@ func TestHandler_newAuthorization(t *testing.T) {
for name, run := range tests {
t.Run(name, func(t *testing.T) {
tc := run(t)
h := &Handler{db: tc.db}
if err := h.newAuthorization(context.Background(), tc.az); err != nil {
// h := &Handler{db: tc.db}
if err := newAuthorization(context.Background(), tc.az); err != nil {
if assert.NotNil(t, tc.err) {
switch k := err.(type) {
case *acme.Error:
@ -1334,11 +1334,11 @@ func TestHandler_NewOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
// h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.NewOrder(w, req)
NewOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1624,11 +1624,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
// h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.FinalizeOrder(w, req)
FinalizeOrder(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)

View file

@ -1057,11 +1057,11 @@ func TestHandler_RevokeCert(t *testing.T) {
for name, setup := range tests {
tc := setup(t)
t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
// h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca}
req := httptest.NewRequest("POST", revokeURL, nil)
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder()
h.RevokeCert(w, req)
RevokeCert(w, req)
res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1198,8 +1198,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
for name, setup := range tests {
tc := setup(t)
t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db}
acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
// h := &Handler{db: tc.db}
acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
expectError := tc.err != nil
gotError := acmeErr != nil