Introduce gocritic linter and address warnings

This commit is contained in:
max furman 2021-10-08 14:59:57 -04:00
parent 9cb1f213d8
commit 933b40a02a
88 changed files with 699 additions and 742 deletions

View file

@ -36,22 +36,30 @@ linters-settings:
- performance - performance
- style - style
- experimental - experimental
- diagnostic
disabled-checks: disabled-checks:
- wrapperFunc - commentFormatting
- dupImport # https://github.com/go-critic/go-critic/issues/845 - commentedOutCode
- evalOrder
- hugeParam
- octalLiteral
- rangeValCopy
- tooManyResultsChecker
- unnamedResult
linters: linters:
disable-all: true disable-all: true
enable: enable:
- gofmt
- revive
- govet
- misspell
- ineffassign
- deadcode - deadcode
- gocritic
- gofmt
- gosimple
- govet
- ineffassign
- misspell
- revive
- staticcheck - staticcheck
- unused - unused
- gosimple
run: run:
skip-dirs: skip-dirs:

View file

@ -19,7 +19,7 @@ type NewAccountRequest struct {
func validateContacts(cs []string) error { func validateContacts(cs []string) error {
for _, c := range cs { for _, c := range cs {
if len(c) == 0 { if c == "" {
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string") return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
} }
} }

View file

@ -178,7 +178,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID) u := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
oids := []string{"foo", "bar"} oids := []string{"foo", "bar"}
oidURLs := []string{ oidURLs := []string{
@ -255,7 +255,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrdersByAccountID(w, req) h.GetOrdersByAccountID(w, req)

View file

@ -148,7 +148,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("authzID", az.ID) chiCtx.URLParams.Add("authzID", az.ID)
url := fmt.Sprintf("%s/acme/%s/authz/%s", u := fmt.Sprintf("%s/acme/%s/authz/%s",
baseURL.String(), provName, az.ID) baseURL.String(), provName, az.ID)
type test struct { type test struct {
@ -280,7 +280,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
expB, err := json.Marshal(az) expB, err := json.Marshal(az)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
@ -314,7 +314,7 @@ func TestHandler_GetCertificate(t *testing.T) {
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("certID", certID) chiCtx.URLParams.Add("certID", certID)
url := fmt.Sprintf("%s/acme/%s/certificate/%s", u := fmt.Sprintf("%s/acme/%s/certificate/%s",
baseURL.String(), provName, certID) baseURL.String(), provName, certID)
type test struct { type test struct {
@ -396,7 +396,7 @@ func TestHandler_GetCertificate(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetCertificate(w, req) h.GetCertificate(w, req)
@ -434,7 +434,7 @@ func TestHandler_GetChallenge(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s", u := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
baseURL.String(), provName, "authzID", "chID") baseURL.String(), provName, "authzID", "chID")
type test struct { type test struct {
@ -635,7 +635,7 @@ func TestHandler_GetChallenge(t *testing.T) {
AuthorizationID: "authzID", AuthorizationID: "authzID",
Type: acme.HTTP01, Type: acme.HTTP01,
AccountID: "accID", AccountID: "accID",
URL: url, URL: u,
Error: acme.NewError(acme.ErrorConnectionType, "force"), Error: acme.NewError(acme.ErrorConnectionType, "force"),
}, },
vco: &acme.ValidateChallengeOptions{ vco: &acme.ValidateChallengeOptions{
@ -652,7 +652,7 @@ func TestHandler_GetChallenge(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetChallenge(w, req) h.GetChallenge(w, req)
@ -678,7 +678,7 @@ func TestHandler_GetChallenge(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")}) assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })

View file

@ -223,7 +223,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive")) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
return return
} }
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 { if hdr.JSONWebKey == nil && hdr.KeyID == "" {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header")) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
return return
} }
@ -367,7 +367,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm { if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match")) api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
return return
} }

View file

@ -108,7 +108,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
} }
func TestHandler_addNonce(t *testing.T) { func TestHandler_addNonce(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-nonce" u := "https://ca.smallstep.com/acme/new-nonce"
type test struct { type test struct {
db acme.DB db acme.DB
err *acme.Error err *acme.Error
@ -141,7 +141,7 @@ func TestHandler_addNonce(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.addNonce(testNext)(w, req) h.addNonce(testNext)(w, req)
res := w.Result() res := w.Result()
@ -230,7 +230,7 @@ func TestHandler_verifyContentType(t *testing.T) {
prov := newProv() prov := newProv()
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
type test struct { type test struct {
h Handler h Handler
ctx context.Context ctx context.Context
@ -245,7 +245,7 @@ func TestHandler_verifyContentType(t *testing.T) {
h: Handler{ h: Handler{
linker: NewLinker("dns", "acme"), linker: NewLinker("dns", "acme"),
}, },
url: url, url: u,
ctx: context.Background(), ctx: context.Background(),
contentType: "foo", contentType: "foo",
statusCode: 500, statusCode: 500,
@ -257,7 +257,7 @@ func TestHandler_verifyContentType(t *testing.T) {
h: Handler{ h: Handler{
linker: NewLinker("dns", "acme"), linker: NewLinker("dns", "acme"),
}, },
url: url, url: u,
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
@ -319,11 +319,11 @@ func TestHandler_verifyContentType(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
_url := url _u := u
if tc.url != "" { if tc.url != "" {
_url = tc.url _u = tc.url
} }
req := httptest.NewRequest("GET", _url, nil) req := httptest.NewRequest("GET", _u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
req.Header.Add("Content-Type", tc.contentType) req.Header.Add("Content-Type", tc.contentType)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -353,7 +353,7 @@ func TestHandler_verifyContentType(t *testing.T) {
} }
func TestHandler_isPostAsGet(t *testing.T) { func TestHandler_isPostAsGet(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-account" u := "https://ca.smallstep.com/acme/new-account"
type test struct { type test struct {
ctx context.Context ctx context.Context
err *acme.Error err *acme.Error
@ -392,7 +392,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.isPostAsGet(testNext)(w, req) h.isPostAsGet(testNext)(w, req)
@ -430,7 +430,7 @@ func (errReader) Close() error {
} }
func TestHandler_parseJWS(t *testing.T) { func TestHandler_parseJWS(t *testing.T) {
url := "https://ca.smallstep.com/acme/new-account" u := "https://ca.smallstep.com/acme/new-account"
type test struct { type test struct {
next nextHTTP next nextHTTP
body io.Reader body io.Reader
@ -483,7 +483,7 @@ func TestHandler_parseJWS(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest("GET", url, tc.body) req := httptest.NewRequest("GET", u, tc.body)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.parseJWS(tc.next)(w, req) h.parseJWS(tc.next)(w, req)
res := w.Result() res := w.Result()
@ -528,7 +528,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
url := "https://ca.smallstep.com/acme/account/1234" u := "https://ca.smallstep.com/acme/account/1234"
type test struct { type test struct {
ctx context.Context ctx context.Context
next func(http.ResponseWriter, *http.Request) next func(http.ResponseWriter, *http.Request)
@ -681,7 +681,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} h := &Handler{}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.verifyAndExtractJWSPayload(tc.next)(w, req) h.verifyAndExtractJWSPayload(tc.next)(w, req)
@ -713,7 +713,7 @@ func TestHandler_lookupJWK(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/account/1234", u := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName) baseURL, provName)
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -883,7 +883,7 @@ func TestHandler_lookupJWK(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker} h := &Handler{db: tc.db, linker: tc.linker}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.lookupJWK(tc.next)(w, req) h.lookupJWK(tc.next)(w, req)
@ -934,7 +934,7 @@ func TestHandler_extractJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234", u := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
provName) provName)
type test struct { type test struct {
db acme.DB db acme.DB
@ -1079,7 +1079,7 @@ func TestHandler_extractJWK(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.extractJWK(tc.next)(w, req) h.extractJWK(tc.next)(w, req)
@ -1108,7 +1108,7 @@ func TestHandler_extractJWK(t *testing.T) {
} }
func TestHandler_validateJWS(t *testing.T) { func TestHandler_validateJWS(t *testing.T) {
url := "https://ca.smallstep.com/acme/account/1234" u := "https://ca.smallstep.com/acme/account/1234"
type test struct { type test struct {
db acme.DB db acme.DB
ctx context.Context ctx context.Context
@ -1198,7 +1198,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.RS256, Algorithm: jose.RS256,
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1226,7 +1226,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.RS256, Algorithm: jose.RS256,
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1298,7 +1298,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url), err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", u),
} }
}, },
"fail/both-jwk-kid": func(t *testing.T) test { "fail/both-jwk-kid": func(t *testing.T) test {
@ -1313,7 +1313,7 @@ func TestHandler_validateJWS(t *testing.T) {
KeyID: "bar", KeyID: "bar",
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1337,7 +1337,7 @@ func TestHandler_validateJWS(t *testing.T) {
Protected: jose.Header{ Protected: jose.Header{
Algorithm: jose.ES256, Algorithm: jose.ES256,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1362,7 +1362,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.ES256, Algorithm: jose.ES256,
KeyID: "bar", KeyID: "bar",
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1392,7 +1392,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.ES256, Algorithm: jose.ES256,
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1422,7 +1422,7 @@ func TestHandler_validateJWS(t *testing.T) {
Algorithm: jose.RS256, Algorithm: jose.RS256,
JSONWebKey: &pub, JSONWebKey: &pub,
ExtraHeaders: map[jose.HeaderKey]interface{}{ ExtraHeaders: map[jose.HeaderKey]interface{}{
"url": url, "url": u,
}, },
}, },
}, },
@ -1446,7 +1446,7 @@ func TestHandler_validateJWS(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} h := &Handler{db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.validateJWS(tc.next)(w, req) h.validateJWS(tc.next)(w, req)

View file

@ -264,7 +264,7 @@ func TestHandler_GetOrder(t *testing.T) {
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID) chiCtx.URLParams.Add("ordID", o.ID)
url := fmt.Sprintf("%s/acme/%s/order/%s", u := fmt.Sprintf("%s/acme/%s/order/%s",
baseURL.String(), escProvName, o.ID) baseURL.String(), escProvName, o.ID)
type test struct { type test struct {
@ -422,7 +422,7 @@ func TestHandler_GetOrder(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrder(w, req) h.GetOrder(w, req)
@ -448,7 +448,7 @@ func TestHandler_GetOrder(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
@ -663,7 +663,7 @@ func TestHandler_NewOrder(t *testing.T) {
prov := newProv() prov := newProv()
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
url := fmt.Sprintf("%s/acme/%s/order/ordID", u := fmt.Sprintf("%s/acme/%s/order/ordID",
baseURL.String(), escProvName) baseURL.String(), escProvName)
type test struct { type test struct {
@ -1335,7 +1335,7 @@ func TestHandler_NewOrder(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.NewOrder(w, req) h.NewOrder(w, req)
@ -1363,7 +1363,7 @@ func TestHandler_NewOrder(t *testing.T) {
tc.vr(t, ro) tc.vr(t, ro)
} }
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })
@ -1406,7 +1406,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
// Request with chi context // Request with chi context
chiCtx := chi.NewRouteContext() chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("ordID", o.ID) chiCtx.URLParams.Add("ordID", o.ID)
url := fmt.Sprintf("%s/acme/%s/order/%s", u := fmt.Sprintf("%s/acme/%s/order/%s",
baseURL.String(), escProvName, o.ID) baseURL.String(), escProvName, o.ID)
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr") _csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
@ -1625,7 +1625,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
req := httptest.NewRequest("GET", url, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.FinalizeOrder(w, req) h.FinalizeOrder(w, req)
@ -1654,7 +1654,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
assert.FatalError(t, json.Unmarshal(body, ro)) assert.FatalError(t, json.Unmarshal(body, ro))
assert.Equals(t, bytes.TrimSpace(body), expB) assert.Equals(t, bytes.TrimSpace(body), expB)
assert.Equals(t, res.Header["Location"], []string{url}) assert.Equals(t, res.Header["Location"], []string{u})
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"}) assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
} }
}) })

View file

@ -76,23 +76,23 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
} }
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(url.String()) resp, err := vo.HTTPGet(u.String())
if err != nil { if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
"error doing http GET for url %s", url)) "error doing http GET for url %s", u))
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return storeError(ctx, db, ch, false, NewError(ErrorConnectionType, return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
"error doing http GET for url %s with status code %d", url, resp.StatusCode)) "error doing http GET for url %s with status code %d", u, resp.StatusCode))
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return WrapErrorISE(err, "error reading "+ return WrapErrorISE(err, "error reading "+
"response body for url %s", url) "response body for url %s", u)
} }
keyAuth := strings.TrimSpace(string(body)) keyAuth := strings.TrimSpace(string(body))

View file

@ -1276,7 +1276,7 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na
oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1} oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
} }
keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:]) keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash)
certTemplate.ExtraExtensions = []pkix.Extension{ certTemplate.ExtraExtensions = []pkix.Extension{
{ {

View file

@ -93,8 +93,8 @@ func TestDB_getDBAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil { if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -109,15 +109,13 @@ func TestDB_getDBAccount(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, dbacc.ID, tc.dbacc.ID)
assert.Equals(t, dbacc.ID, tc.dbacc.ID) assert.Equals(t, dbacc.Status, tc.dbacc.Status)
assert.Equals(t, dbacc.Status, tc.dbacc.Status) assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt) assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt)
assert.Equals(t, dbacc.DeactivatedAt, tc.dbacc.DeactivatedAt) assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact) assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
}
} }
}) })
} }
@ -174,8 +172,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil { if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -190,10 +188,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, retAccID, accID)
assert.Equals(t, retAccID, accID)
}
} }
}) })
} }
@ -250,8 +246,8 @@ func TestDB_GetAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if acc, err := db.GetAccount(context.Background(), accID); err != nil { if acc, err := d.GetAccount(context.Background(), accID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -266,13 +262,11 @@ func TestDB_GetAccount(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, acc.ID, tc.dbacc.ID)
assert.Equals(t, acc.ID, tc.dbacc.ID) assert.Equals(t, acc.Status, tc.dbacc.Status)
assert.Equals(t, acc.Status, tc.dbacc.Status) assert.Equals(t, acc.Contact, tc.dbacc.Contact)
assert.Equals(t, acc.Contact, tc.dbacc.Contact) assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
}
} }
}) })
} }
@ -358,8 +352,8 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil { if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -374,13 +368,11 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, acc.ID, tc.dbacc.ID)
assert.Equals(t, acc.ID, tc.dbacc.ID) assert.Equals(t, acc.Status, tc.dbacc.Status)
assert.Equals(t, acc.Status, tc.dbacc.Status) assert.Equals(t, acc.Contact, tc.dbacc.Contact)
assert.Equals(t, acc.Contact, tc.dbacc.Contact) assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
}
} }
}) })
} }
@ -527,8 +519,8 @@ func TestDB_CreateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateAccount(context.Background(), tc.acc); err != nil { if err := d.CreateAccount(context.Background(), tc.acc); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -688,8 +680,8 @@ func TestDB_UpdateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.UpdateAccount(context.Background(), tc.acc); err != nil { if err := d.UpdateAccount(context.Background(), tc.acc); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -97,8 +97,8 @@ func TestDB_getDBAuthz(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil { if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -113,18 +113,16 @@ func TestDB_getDBAuthz(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, dbaz.ID, tc.dbaz.ID)
assert.Equals(t, dbaz.ID, tc.dbaz.ID) assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID) assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier) assert.Equals(t, dbaz.Status, tc.dbaz.Status)
assert.Equals(t, dbaz.Status, tc.dbaz.Status) assert.Equals(t, dbaz.Token, tc.dbaz.Token)
assert.Equals(t, dbaz.Token, tc.dbaz.Token) assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt)
assert.Equals(t, dbaz.CreatedAt, tc.dbaz.CreatedAt) assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt)
assert.Equals(t, dbaz.ExpiresAt, tc.dbaz.ExpiresAt) assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error()) assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
}
} }
}) })
} }
@ -293,8 +291,8 @@ func TestDB_GetAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if az, err := db.GetAuthorization(context.Background(), azID); err != nil { if az, err := d.GetAuthorization(context.Background(), azID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -309,21 +307,19 @@ func TestDB_GetAuthorization(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, az.ID, tc.dbaz.ID)
assert.Equals(t, az.ID, tc.dbaz.ID) assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
assert.Equals(t, az.AccountID, tc.dbaz.AccountID) assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
assert.Equals(t, az.Identifier, tc.dbaz.Identifier) assert.Equals(t, az.Status, tc.dbaz.Status)
assert.Equals(t, az.Status, tc.dbaz.Status) assert.Equals(t, az.Token, tc.dbaz.Token)
assert.Equals(t, az.Token, tc.dbaz.Token) assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard)
assert.Equals(t, az.Wildcard, tc.dbaz.Wildcard) assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt)
assert.Equals(t, az.ExpiresAt, tc.dbaz.ExpiresAt) assert.Equals(t, az.Challenges, []*acme.Challenge{
assert.Equals(t, az.Challenges, []*acme.Challenge{ {ID: "foo"},
{ID: "foo"}, {ID: "bar"},
{ID: "bar"}, })
}) assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
}
} }
}) })
} }
@ -445,8 +441,8 @@ func TestDB_CreateAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateAuthorization(context.Background(), tc.az); err != nil { if err := d.CreateAuthorization(context.Background(), tc.az); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -594,8 +590,8 @@ func TestDB_UpdateAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil { if err := d.UpdateAuthorization(context.Background(), tc.az); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -98,8 +98,8 @@ func TestDB_CreateCertificate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateCertificate(context.Background(), tc.cert); err != nil { if err := d.CreateCertificate(context.Background(), tc.cert); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -228,8 +228,8 @@ func TestDB_GetCertificate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
cert, err := db.GetCertificate(context.Background(), certID) cert, err := d.GetCertificate(context.Background(), certID)
if err != nil { if err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
@ -245,14 +245,12 @@ func TestDB_GetCertificate(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, cert.ID, certID)
assert.Equals(t, cert.ID, certID) assert.Equals(t, cert.AccountID, "accountID")
assert.Equals(t, cert.AccountID, "accountID") assert.Equals(t, cert.OrderID, "orderID")
assert.Equals(t, cert.OrderID, "orderID") assert.Equals(t, cert.Leaf, leaf)
assert.Equals(t, cert.Leaf, leaf) assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
}
} }
}) })
} }

View file

@ -92,8 +92,8 @@ func TestDB_getDBChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil { if ch, err := d.getDBChallenge(context.Background(), chID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -108,17 +108,15 @@ func TestDB_getDBChallenge(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, ch.ID, tc.dbc.ID)
assert.Equals(t, ch.ID, tc.dbc.ID) assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
assert.Equals(t, ch.AccountID, tc.dbc.AccountID) assert.Equals(t, ch.Type, tc.dbc.Type)
assert.Equals(t, ch.Type, tc.dbc.Type) assert.Equals(t, ch.Status, tc.dbc.Status)
assert.Equals(t, ch.Status, tc.dbc.Status) assert.Equals(t, ch.Token, tc.dbc.Token)
assert.Equals(t, ch.Token, tc.dbc.Token) assert.Equals(t, ch.Value, tc.dbc.Value)
assert.Equals(t, ch.Value, tc.dbc.Value) assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
}
} }
}) })
} }
@ -206,8 +204,8 @@ func TestDB_CreateChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil { if err := d.CreateChallenge(context.Background(), tc.ch); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -286,8 +284,8 @@ func TestDB_GetChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil { if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -302,17 +300,15 @@ func TestDB_GetChallenge(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, ch.ID, tc.dbc.ID)
assert.Equals(t, ch.ID, tc.dbc.ID) assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
assert.Equals(t, ch.AccountID, tc.dbc.AccountID) assert.Equals(t, ch.Type, tc.dbc.Type)
assert.Equals(t, ch.Type, tc.dbc.Type) assert.Equals(t, ch.Status, tc.dbc.Status)
assert.Equals(t, ch.Status, tc.dbc.Status) assert.Equals(t, ch.Token, tc.dbc.Token)
assert.Equals(t, ch.Token, tc.dbc.Token) assert.Equals(t, ch.Value, tc.dbc.Value)
assert.Equals(t, ch.Value, tc.dbc.Value) assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt) assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
}
} }
}) })
} }
@ -442,8 +438,8 @@ func TestDB_UpdateChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil { if err := d.UpdateChallenge(context.Background(), tc.ch); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -31,7 +31,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
ID: id, ID: id,
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
} }
if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil { if err := db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
return "", err return "", err
} }
return acme.Nonce(id), nil return acme.Nonce(id), nil

View file

@ -67,8 +67,8 @@ func TestDB_CreateNonce(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if n, err := db.CreateNonce(context.Background()); err != nil { if n, err := d.CreateNonce(context.Background()); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -144,8 +144,8 @@ func TestDB_DeleteNonce(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil { if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {

View file

@ -41,7 +41,7 @@ func New(db nosqlDB.DB) (*DB, error) {
// save writes the new data to the database, overwriting the old data if it // save writes the new data to the database, overwriting the old data if it
// existed. // existed.
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
var ( var (
err error err error
newB []byte newB []byte

View file

@ -126,8 +126,8 @@ func TestDB_save(t *testing.T) {
} }
for name, tc := range tests { for name, tc := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := &DB{db: tc.db} d := &DB{db: tc.db}
if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil { if err := d.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }

View file

@ -13,7 +13,6 @@ import (
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
nosqldb "github.com/smallstep/nosql/database"
) )
func TestDB_getDBOrder(t *testing.T) { func TestDB_getDBOrder(t *testing.T) {
@ -32,7 +31,7 @@ func TestDB_getDBOrder(t *testing.T) {
assert.Equals(t, bucket, orderTable) assert.Equals(t, bucket, orderTable)
assert.Equals(t, string(key), orderID) assert.Equals(t, string(key), orderID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
@ -101,8 +100,8 @@ func TestDB_getDBOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil { if dbo, err := d.getDBOrder(context.Background(), orderID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -117,20 +116,18 @@ func TestDB_getDBOrder(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, dbo.ID, tc.dbo.ID)
assert.Equals(t, dbo.ID, tc.dbo.ID) assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID) assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID) assert.Equals(t, dbo.Status, tc.dbo.Status)
assert.Equals(t, dbo.Status, tc.dbo.Status) assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt)
assert.Equals(t, dbo.CreatedAt, tc.dbo.CreatedAt) assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt)
assert.Equals(t, dbo.ExpiresAt, tc.dbo.ExpiresAt) assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore)
assert.Equals(t, dbo.NotBefore, tc.dbo.NotBefore) assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter)
assert.Equals(t, dbo.NotAfter, tc.dbo.NotAfter) assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers)
assert.Equals(t, dbo.Identifiers, tc.dbo.Identifiers) assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs) assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
}
} }
}) })
} }
@ -165,7 +162,7 @@ func TestDB_GetOrder(t *testing.T) {
assert.Equals(t, bucket, orderTable) assert.Equals(t, bucket, orderTable)
assert.Equals(t, string(key), orderID) assert.Equals(t, string(key), orderID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"), acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
@ -207,8 +204,8 @@ func TestDB_GetOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if o, err := db.GetOrder(context.Background(), orderID); err != nil { if o, err := d.GetOrder(context.Background(), orderID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
if assert.NotNil(t, tc.acmeErr) { if assert.NotNil(t, tc.acmeErr) {
@ -223,20 +220,18 @@ func TestDB_GetOrder(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, o.ID, tc.dbo.ID)
assert.Equals(t, o.ID, tc.dbo.ID) assert.Equals(t, o.AccountID, tc.dbo.AccountID)
assert.Equals(t, o.AccountID, tc.dbo.AccountID) assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID) assert.Equals(t, o.CertificateID, tc.dbo.CertificateID)
assert.Equals(t, o.CertificateID, tc.dbo.CertificateID) assert.Equals(t, o.Status, tc.dbo.Status)
assert.Equals(t, o.Status, tc.dbo.Status) assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt)
assert.Equals(t, o.ExpiresAt, tc.dbo.ExpiresAt) assert.Equals(t, o.NotBefore, tc.dbo.NotBefore)
assert.Equals(t, o.NotBefore, tc.dbo.NotBefore) assert.Equals(t, o.NotAfter, tc.dbo.NotAfter)
assert.Equals(t, o.NotAfter, tc.dbo.NotAfter) assert.Equals(t, o.Identifiers, tc.dbo.Identifiers)
assert.Equals(t, o.Identifiers, tc.dbo.Identifiers) assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs) assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
}
} }
}) })
} }
@ -367,8 +362,8 @@ func TestDB_UpdateOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.UpdateOrder(context.Background(), tc.o); err != nil { if err := d.UpdateOrder(context.Background(), tc.o); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -512,7 +507,7 @@ func TestDB_CreateOrder(t *testing.T) {
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, string(bucket), string(ordersByAccountIDTable)) assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
assert.Equals(t, string(key), o.AccountID) assert.Equals(t, string(key), o.AccountID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
switch string(bucket) { switch string(bucket) {
@ -558,8 +553,8 @@ func TestDB_CreateOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if err := db.CreateOrder(context.Background(), tc.o); err != nil { if err := d.CreateOrder(context.Background(), tc.o); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
@ -681,7 +676,7 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
MGet: func(bucket, key []byte) ([]byte, error) { MGet: func(bucket, key []byte) ([]byte, error) {
assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, bucket, ordersByAccountIDTable)
assert.Equals(t, key, []byte(accID)) assert.Equals(t, key, []byte(accID))
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) { MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
assert.Equals(t, bucket, ordersByAccountIDTable) assert.Equals(t, bucket, ordersByAccountIDTable)
@ -996,15 +991,15 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
var ( var (
res []string res []string
err error err error
) )
if tc.addOids == nil { if tc.addOids == nil {
res, err = db.updateAddOrderIDs(context.Background(), accID) res, err = d.updateAddOrderIDs(context.Background(), accID)
} else { } else {
res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...) res, err = d.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
} }
if err != nil { if err != nil {
@ -1022,10 +1017,8 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.True(t, reflect.DeepEqual(res, tc.res))
assert.True(t, reflect.DeepEqual(res, tc.res))
}
} }
}) })
} }

View file

@ -289,6 +289,7 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate
// name or in an extensionRequest attribute [RFC2985] requesting a // name or in an extensionRequest attribute [RFC2985] requesting a
// subjectAltName extension, or both. // subjectAltName extension, or both.
if csr.Subject.CommonName != "" { if csr.Subject.CommonName != "" {
// nolint:gocritic
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName) canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
} }
canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames) canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames)

View file

@ -240,9 +240,9 @@ type caHandler struct {
} }
// New creates a new RouterHandler with the CA endpoints. // New creates a new RouterHandler with the CA endpoints.
func New(authority Authority) RouterHandler { func New(auth Authority) RouterHandler {
return &caHandler{ return &caHandler{
Authority: authority, Authority: auth,
} }
} }
@ -295,7 +295,7 @@ func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
// certificate for the given SHA256. // certificate for the given SHA256.
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
sha := chi.URLParam(r, "sha") sha := chi.URLParam(r, "sha")
sum := strings.ToLower(strings.Replace(sha, "-", "", -1)) sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
// Load root certificate with the // Load root certificate with the
cert, err := h.Authority.Root(sum) cert, err := h.Authority.Root(sum)
if err != nil { if err != nil {
@ -409,19 +409,20 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
"certificate": base64.StdEncoding.EncodeToString(cert.Raw), "certificate": base64.StdEncoding.EncodeToString(cert.Raw),
} }
for _, ext := range cert.Extensions { for _, ext := range cert.Extensions {
if ext.Id.Equal(oidStepProvisioner) { if !ext.Id.Equal(oidStepProvisioner) {
val := &stepProvisioner{} continue
rest, err := asn1.Unmarshal(ext.Value, val) }
if err != nil || len(rest) > 0 { val := &stepProvisioner{}
break rest, err := asn1.Unmarshal(ext.Value, val)
} if err != nil || len(rest) > 0 {
if len(val.CredentialID) > 0 {
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
} else {
m["provisioner"] = string(val.Name)
}
break break
} }
if len(val.CredentialID) > 0 {
m["provisioner"] = fmt.Sprintf("%s (%s)", val.Name, val.CredentialID)
} else {
m["provisioner"] = string(val.Name)
}
break
} }
rl.WithFields(m) rl.WithFields(m)
} }

View file

@ -186,8 +186,8 @@ func TestCertificate_MarshalJSON(t *testing.T) {
}{ }{
{"nil", fields{Certificate: nil}, []byte("null"), false}, {"nil", fields{Certificate: nil}, []byte("null"), false},
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false}, {"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false}, {"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false},
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false}, {"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false},
} }
for _, tt := range tests { for _, tt := range tests {
@ -219,11 +219,11 @@ func TestCertificate_UnmarshalJSON(t *testing.T) {
{"invalid string", []byte(`"foobar"`), false, true}, {"invalid string", []byte(`"foobar"`), false, true},
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
{"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false, true}, {"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true},
{"empty string", []byte(`""`), false, false}, {"empty string", []byte(`""`), false, false},
{"json null", []byte(`null`), false, false}, {"json null", []byte(`null`), false, false},
{"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true, false}, {"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false},
{"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), true, false}, {"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -251,7 +251,7 @@ func TestCertificate_UnmarshalJSON_json(t *testing.T) {
{"empty crt (null)", `{"crt":null}`, false, false}, {"empty crt (null)", `{"crt":null}`, false, false},
{"empty crt (string)", `{"crt":""}`, false, false}, {"empty crt (string)", `{"crt":""}`, false, false},
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true}, {"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
{"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, true, false}, {"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false},
} }
type request struct { type request struct {
@ -297,7 +297,7 @@ func TestCertificateRequest_MarshalJSON(t *testing.T) {
}{ }{
{"nil", fields{CertificateRequest: nil}, []byte("null"), false}, {"nil", fields{CertificateRequest: nil}, []byte("null"), false},
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false}, {"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false}, {"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false},
} }
for _, tt := range tests { for _, tt := range tests {
@ -329,10 +329,10 @@ func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
{"invalid string", []byte(`"foobar"`), false, true}, {"invalid string", []byte(`"foobar"`), false, true},
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true}, {"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true}, {"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
{"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false, true}, {"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true},
{"empty string", []byte(`""`), false, false}, {"empty string", []byte(`""`), false, false},
{"json null", []byte(`null`), false, false}, {"json null", []byte(`null`), false, false},
{"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true, false}, {"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -360,7 +360,7 @@ func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
{"empty csr (null)", `{"csr":null}`, false, false}, {"empty csr (null)", `{"csr":null}`, false, false},
{"empty csr (string)", `{"csr":""}`, false, false}, {"empty csr (string)", `{"csr":""}`, false, false},
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true}, {"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
{"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, true, false}, {"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false},
} }
type request struct { type request struct {
@ -739,7 +739,7 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin
return m.ret1.(bool), m.err return m.ret1.(bool), m.err
} }
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) { func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
if m.getSSHBastion != nil { if m.getSSHBastion != nil {
return m.getSSHBastion(ctx, user, hostname) return m.getSSHBastion(ctx, user, hostname)
} }
@ -816,7 +816,7 @@ func Test_caHandler_Root(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil) req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)) req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`) expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -860,8 +860,8 @@ func Test_caHandler_Sign(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
expected1 := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
expected2 := []byte(`{"crt":"` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
tests := []struct { tests := []struct {
name string name string
@ -934,7 +934,7 @@ func Test_caHandler_Renew(t *testing.T) {
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, {"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
} }
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -995,7 +995,7 @@ func Test_caHandler_Rekey(t *testing.T) {
{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest}, {"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},
} }
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -1210,7 +1210,7 @@ func Test_caHandler_Roots(t *testing.T) {
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
} }
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -1256,7 +1256,7 @@ func Test_caHandler_Federation(t *testing.T) {
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden}, {"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
} }
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`) expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -50,12 +50,10 @@ func WriteError(w http.ResponseWriter, err error) {
rl.WithFields(map[string]interface{}{ rl.WithFields(map[string]interface{}{
"stack-trace": fmt.Sprintf("%+v", e), "stack-trace": fmt.Sprintf("%+v", e),
}) })
} else { } else if e, ok := cause.(errs.StackTracer); ok {
if e, ok := cause.(errs.StackTracer); ok { rl.WithFields(map[string]interface{}{
rl.WithFields(map[string]interface{}{ "stack-trace": fmt.Sprintf("%+v", e),
"stack-trace": fmt.Sprintf("%+v", e), })
})
}
} }
} }
} }

View file

@ -52,7 +52,7 @@ func (s *SSHSignRequest) Validate() error {
return errors.Errorf("unknown certType %s", s.CertType) return errors.Errorf("unknown certType %s", s.CertType)
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty publicKey") return errors.New("missing or empty publicKey")
case len(s.OTT) == 0: case s.OTT == "":
return errors.New("missing or empty ott") return errors.New("missing or empty ott")
default: default:
// Validate identity signature if provided // Validate identity signature if provided
@ -408,18 +408,18 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
var config SSHConfigResponse var cfg SSHConfigResponse
switch body.Type { switch body.Type {
case provisioner.SSHUserCert: case provisioner.SSHUserCert:
config.UserTemplates = ts cfg.UserTemplates = ts
case provisioner.SSHHostCert: case provisioner.SSHHostCert:
config.HostTemplates = ts cfg.HostTemplates = ts
default: default:
WriteError(w, errs.InternalServer("it should hot get here")) WriteError(w, errs.InternalServer("it should hot get here"))
return return
} }
JSON(w, config) JSON(w, cfg)
} }
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.

View file

@ -19,7 +19,7 @@ type SSHRekeyRequest struct {
// Validate validates the SSHSignRekey. // Validate validates the SSHSignRekey.
func (s *SSHRekeyRequest) Validate() error { func (s *SSHRekeyRequest) Validate() error {
switch { switch {
case len(s.OTT) == 0: case s.OTT == "":
return errors.New("missing or empty ott") return errors.New("missing or empty ott")
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty public key") return errors.New("missing or empty public key")

View file

@ -18,7 +18,7 @@ type SSHRenewRequest struct {
// Validate validates the SSHSignRequest. // Validate validates the SSHSignRequest.
func (s *SSHRenewRequest) Validate() error { func (s *SSHRenewRequest) Validate() error {
switch { switch {
case len(s.OTT) == 0: case s.OTT == "":
return errors.New("missing or empty ott") return errors.New("missing or empty ott")
default: default:
return nil return nil

View file

@ -36,7 +36,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
if !r.Passive { if !r.Passive {
return errs.NotImplemented("non-passive revocation not implemented") return errs.NotImplemented("non-passive revocation not implemented")
} }
if len(r.OTT) == 0 { if r.OTT == "" {
return errs.BadRequest("missing ott") return errs.BadRequest("missing ott")
} }
return return

View file

@ -284,7 +284,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
identityCerts := []*x509.Certificate{ identityCerts := []*x509.Certificate{
parseCertificate(certPEM), parseCertificate(certPEM),
} }
identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`) identityCertsPEM := []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`)
tests := []struct { tests := []struct {
name string name string

View file

@ -27,7 +27,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization") tok := r.Header.Get("Authorization")
if len(tok) == 0 { if tok == "" {
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
"missing authorization header token")) "missing authorization header token"))
return return

View file

@ -12,7 +12,6 @@ import (
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
nosqldb "github.com/smallstep/nosql/database"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
) )
@ -32,7 +31,7 @@ func TestDB_getDBAdminBytes(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -67,8 +66,8 @@ func TestDB_getDBAdminBytes(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if b, err := db.getDBAdminBytes(context.Background(), adminID); err != nil { if b, err := d.getDBAdminBytes(context.Background(), adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -83,10 +82,8 @@ func TestDB_getDBAdminBytes(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.Equals(t, string(b), "foo")
assert.Equals(t, string(b), "foo")
}
} }
}) })
} }
@ -108,7 +105,7 @@ func TestDB_getDBAdmin(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -193,8 +190,8 @@ func TestDB_getDBAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if dba, err := db.getDBAdmin(context.Background(), adminID); err != nil { if dba, err := d.getDBAdmin(context.Background(), adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -209,16 +206,14 @@ func TestDB_getDBAdmin(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dba.ID, adminID)
assert.Equals(t, dba.ID, adminID) assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID) assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID) assert.Equals(t, dba.Subject, tc.dba.Subject)
assert.Equals(t, dba.Subject, tc.dba.Subject) assert.Equals(t, dba.Type, tc.dba.Type)
assert.Equals(t, dba.Type, tc.dba.Type) assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt) assert.Fatal(t, dba.DeletedAt.IsZero())
assert.Fatal(t, dba.DeletedAt.IsZero())
}
} }
}) })
} }
@ -283,8 +278,8 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{authorityID: admin.DefaultAuthorityID} d := DB{authorityID: admin.DefaultAuthorityID}
if dba, err := db.unmarshalDBAdmin(tc.in, adminID); err != nil { if dba, err := d.unmarshalDBAdmin(tc.in, adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -299,16 +294,14 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dba.ID, adminID)
assert.Equals(t, dba.ID, adminID) assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID) assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID) assert.Equals(t, dba.Subject, tc.dba.Subject)
assert.Equals(t, dba.Subject, tc.dba.Subject) assert.Equals(t, dba.Type, tc.dba.Type)
assert.Equals(t, dba.Type, tc.dba.Type) assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt) assert.Fatal(t, dba.DeletedAt.IsZero())
assert.Fatal(t, dba.DeletedAt.IsZero())
}
} }
}) })
} }
@ -360,8 +353,8 @@ func TestDB_unmarshalAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{authorityID: admin.DefaultAuthorityID} d := DB{authorityID: admin.DefaultAuthorityID}
if adm, err := db.unmarshalAdmin(tc.in, adminID); err != nil { if adm, err := d.unmarshalAdmin(tc.in, adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -376,16 +369,14 @@ func TestDB_unmarshalAdmin(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, adm.Id, adminID)
assert.Equals(t, adm.Id, adminID) assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID) assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID) assert.Equals(t, adm.Subject, tc.dba.Subject)
assert.Equals(t, adm.Subject, tc.dba.Subject) assert.Equals(t, adm.Type, tc.dba.Type)
assert.Equals(t, adm.Type, tc.dba.Type) assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt)) assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
}
} }
}) })
} }
@ -407,7 +398,7 @@ func TestDB_GetAdmin(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -516,8 +507,8 @@ func TestDB_GetAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if adm, err := db.GetAdmin(context.Background(), adminID); err != nil { if adm, err := d.GetAdmin(context.Background(), adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -532,16 +523,14 @@ func TestDB_GetAdmin(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, adm.Id, adminID)
assert.Equals(t, adm.Id, adminID) assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID) assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID) assert.Equals(t, adm.Subject, tc.dba.Subject)
assert.Equals(t, adm.Subject, tc.dba.Subject) assert.Equals(t, adm.Type, tc.dba.Type)
assert.Equals(t, adm.Type, tc.dba.Type) assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt)) assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
}
} }
}) })
} }
@ -562,7 +551,7 @@ func TestDB_DeleteAdmin(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -670,8 +659,8 @@ func TestDB_DeleteAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.DeleteAdmin(context.Background(), adminID); err != nil { if err := d.DeleteAdmin(context.Background(), adminID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -708,7 +697,7 @@ func TestDB_UpdateAdmin(t *testing.T) {
assert.Equals(t, bucket, adminsTable) assert.Equals(t, bucket, adminsTable)
assert.Equals(t, string(key), adminID) assert.Equals(t, string(key), adminID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
@ -821,8 +810,8 @@ func TestDB_UpdateAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.UpdateAdmin(context.Background(), tc.adm); err != nil { if err := d.UpdateAdmin(context.Background(), tc.adm); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -919,8 +908,8 @@ func TestDB_CreateAdmin(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.CreateAdmin(context.Background(), tc.adm); err != nil { if err := d.CreateAdmin(context.Background(), tc.adm); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -1095,8 +1084,8 @@ func TestDB_GetAdmins(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if admins, err := db.GetAdmins(context.Background()); err != nil { if admins, err := d.GetAdmins(context.Background()); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -1111,10 +1100,8 @@ func TestDB_GetAdmins(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { tc.verify(t, admins)
tc.verify(t, admins)
}
} }
}) })
} }

View file

@ -35,7 +35,7 @@ func New(db nosqlDB.DB, authorityID string) (*DB, error) {
// save writes the new data to the database, overwriting the old data if it // save writes the new data to the database, overwriting the old data if it
// existed. // existed.
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error { func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
var ( var (
err error err error
newB []byte newB []byte

View file

@ -12,7 +12,6 @@ import (
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"github.com/smallstep/nosql/database" "github.com/smallstep/nosql/database"
nosqldb "github.com/smallstep/nosql/database"
"go.step.sm/linkedca" "go.step.sm/linkedca"
) )
@ -31,7 +30,7 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -66,8 +65,8 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db} d := DB{db: tc.db}
if b, err := db.getDBProvisionerBytes(context.Background(), provID); err != nil { if b, err := d.getDBProvisionerBytes(context.Background(), provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -82,10 +81,8 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, string(b), "foo")
assert.Equals(t, string(b), "foo")
}
} }
}) })
} }
@ -107,7 +104,7 @@ func TestDB_getDBProvisioner(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -190,8 +187,8 @@ func TestDB_getDBProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if dbp, err := db.getDBProvisioner(context.Background(), provID); err != nil { if dbp, err := d.getDBProvisioner(context.Background(), provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -206,15 +203,13 @@ func TestDB_getDBProvisioner(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dbp.ID, provID)
assert.Equals(t, dbp.ID, provID) assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID) assert.Equals(t, dbp.Type, tc.dbp.Type)
assert.Equals(t, dbp.Type, tc.dbp.Type) assert.Equals(t, dbp.Name, tc.dbp.Name)
assert.Equals(t, dbp.Name, tc.dbp.Name) assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt) assert.Fatal(t, dbp.DeletedAt.IsZero())
assert.Fatal(t, dbp.DeletedAt.IsZero())
}
} }
}) })
} }
@ -278,8 +273,8 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{authorityID: admin.DefaultAuthorityID} d := DB{authorityID: admin.DefaultAuthorityID}
if dbp, err := db.unmarshalDBProvisioner(tc.in, provID); err != nil { if dbp, err := d.unmarshalDBProvisioner(tc.in, provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -294,19 +289,17 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, dbp.ID, provID)
assert.Equals(t, dbp.ID, provID) assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID) assert.Equals(t, dbp.Type, tc.dbp.Type)
assert.Equals(t, dbp.Type, tc.dbp.Type) assert.Equals(t, dbp.Name, tc.dbp.Name)
assert.Equals(t, dbp.Name, tc.dbp.Name) assert.Equals(t, dbp.Details, tc.dbp.Details)
assert.Equals(t, dbp.Details, tc.dbp.Details) assert.Equals(t, dbp.Claims, tc.dbp.Claims)
assert.Equals(t, dbp.Claims, tc.dbp.Claims) assert.Equals(t, dbp.X509Template, tc.dbp.X509Template)
assert.Equals(t, dbp.X509Template, tc.dbp.X509Template) assert.Equals(t, dbp.SSHTemplate, tc.dbp.SSHTemplate)
assert.Equals(t, dbp.SSHTemplate, tc.dbp.SSHTemplate) assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt) assert.Fatal(t, dbp.DeletedAt.IsZero())
assert.Fatal(t, dbp.DeletedAt.IsZero())
}
} }
}) })
} }
@ -402,8 +395,8 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{authorityID: admin.DefaultAuthorityID} d := DB{authorityID: admin.DefaultAuthorityID}
if prov, err := db.unmarshalProvisioner(tc.in, provID); err != nil { if prov, err := d.unmarshalProvisioner(tc.in, provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -418,20 +411,18 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, prov.Id, provID)
assert.Equals(t, prov.Id, provID) assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID) assert.Equals(t, prov.Type, tc.dbp.Type)
assert.Equals(t, prov.Type, tc.dbp.Type) assert.Equals(t, prov.Name, tc.dbp.Name)
assert.Equals(t, prov.Name, tc.dbp.Name) assert.Equals(t, prov.Claims, tc.dbp.Claims)
assert.Equals(t, prov.Claims, tc.dbp.Claims) assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
assert.Equals(t, prov.X509Template, tc.dbp.X509Template) assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
retDetailsBytes, err := json.Marshal(prov.Details.GetData()) retDetailsBytes, err := json.Marshal(prov.Details.GetData())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, retDetailsBytes, tc.dbp.Details) assert.Equals(t, retDetailsBytes, tc.dbp.Details)
}
} }
}) })
} }
@ -453,7 +444,7 @@ func TestDB_GetProvisioner(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -542,8 +533,8 @@ func TestDB_GetProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if prov, err := db.GetProvisioner(context.Background(), provID); err != nil { if prov, err := d.GetProvisioner(context.Background(), provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -558,20 +549,18 @@ func TestDB_GetProvisioner(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { assert.Equals(t, prov.Id, provID)
assert.Equals(t, prov.Id, provID) assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID) assert.Equals(t, prov.Type, tc.dbp.Type)
assert.Equals(t, prov.Type, tc.dbp.Type) assert.Equals(t, prov.Name, tc.dbp.Name)
assert.Equals(t, prov.Name, tc.dbp.Name) assert.Equals(t, prov.Claims, tc.dbp.Claims)
assert.Equals(t, prov.Claims, tc.dbp.Claims) assert.Equals(t, prov.X509Template, tc.dbp.X509Template)
assert.Equals(t, prov.X509Template, tc.dbp.X509Template) assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
assert.Equals(t, prov.SshTemplate, tc.dbp.SSHTemplate)
retDetailsBytes, err := json.Marshal(prov.Details.GetData()) retDetailsBytes, err := json.Marshal(prov.Details.GetData())
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, retDetailsBytes, tc.dbp.Details) assert.Equals(t, retDetailsBytes, tc.dbp.Details)
}
} }
}) })
} }
@ -592,7 +581,7 @@ func TestDB_DeleteProvisioner(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -692,8 +681,8 @@ func TestDB_DeleteProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.DeleteProvisioner(context.Background(), provID); err != nil { if err := d.DeleteProvisioner(context.Background(), provID); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -853,8 +842,8 @@ func TestDB_GetProvisioners(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if provs, err := db.GetProvisioners(context.Background()); err != nil { if provs, err := d.GetProvisioners(context.Background()); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -869,10 +858,8 @@ func TestDB_GetProvisioners(t *testing.T) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} }
} else { } else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) { tc.verify(t, provs)
tc.verify(t, provs)
}
} }
}) })
} }
@ -963,8 +950,8 @@ func TestDB_CreateProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.CreateProvisioner(context.Background(), tc.prov); err != nil { if err := d.CreateProvisioner(context.Background(), tc.prov); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {
@ -1001,7 +988,7 @@ func TestDB_UpdateProvisioner(t *testing.T) {
assert.Equals(t, bucket, provisionersTable) assert.Equals(t, bucket, provisionersTable)
assert.Equals(t, string(key), provID) assert.Equals(t, string(key), provID)
return nil, nosqldb.ErrNotFound return nil, database.ErrNotFound
}, },
}, },
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"), adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
@ -1199,8 +1186,8 @@ func TestDB_UpdateProvisioner(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID} d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
if err := db.UpdateProvisioner(context.Background(), tc.prov); err != nil { if err := d.UpdateProvisioner(context.Background(), tc.prov); err != nil {
switch k := err.(type) { switch k := err.(type) {
case *admin.Error: case *admin.Error:
if assert.NotNil(t, tc.adminErr) { if assert.NotNil(t, tc.adminErr) {

View file

@ -55,8 +55,8 @@ type subProv struct {
provisioner string provisioner string
} }
func newSubProv(subject, provisioner string) subProv { func newSubProv(subject, prov string) subProv {
return subProv{subject, provisioner} return subProv{subject, prov}
} }
// LoadBySubProv a admin by the subject and provisioner name. // LoadBySubProv a admin by the subject and provisioner name.

View file

@ -16,10 +16,10 @@ func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
} }
// LoadAdminBySubProv returns an *linkedca.Admin with the given ID. // LoadAdminBySubProv returns an *linkedca.Admin with the given ID.
func (a *Authority) LoadAdminBySubProv(subject, provisioner string) (*linkedca.Admin, bool) { func (a *Authority) LoadAdminBySubProv(subject, prov string) (*linkedca.Admin, bool) {
a.adminMutex.RLock() a.adminMutex.RLock()
defer a.adminMutex.RUnlock() defer a.adminMutex.RUnlock()
return a.admins.LoadBySubProv(subject, provisioner) return a.admins.LoadBySubProv(subject, prov)
} }
// GetAdmins returns a map listing each provisioner and the JWK Key Set // GetAdmins returns a map listing each provisioner and the JWK Key Set

View file

@ -78,14 +78,14 @@ type Authority struct {
} }
// New creates and initiates a new Authority type. // New creates and initiates a new Authority type.
func New(config *config.Config, opts ...Option) (*Authority, error) { func New(cfg *config.Config, opts ...Option) (*Authority, error) {
err := config.Validate() err := cfg.Validate()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var a = &Authority{ var a = &Authority{
config: config, config: cfg,
certificates: new(sync.Map), certificates: new(sync.Map),
} }

View file

@ -54,7 +54,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// key in order to verify the claims and we need the issuer from the claims // key in order to verify the claims and we need the issuer from the claims
// before we can look up the provisioner. // before we can look up the provisioner.
var claims Claims var claims Claims
if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil { if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
} }
@ -77,7 +77,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// Store the token to protect against reuse unless it's skipped. // Store the token to protect against reuse unless it's skipped.
// If we cannot get a token id from the provisioner, just hash the token. // If we cannot get a token id from the provisioner, just hash the token.
if !SkipTokenReuseFromContext(ctx) { if !SkipTokenReuseFromContext(ctx) {
if err = a.UseToken(token, p); err != nil { if err := a.UseToken(token, p); err != nil {
return nil, err return nil, err
} }
} }
@ -112,7 +112,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
// to the public certificate in the `x5c` header of the token. // to the public certificate in the `x5c` header of the token.
// 2. Asserts that the claims are valid - have not been tampered with. // 2. Asserts that the claims are valid - have not been tampered with.
var claims jose.Claims var claims jose.Claims
if err = jwt.Claims(leaf.PublicKey, &claims); err != nil { if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims") return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims")
} }
@ -122,13 +122,13 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
} }
// Check that the token has not been used. // Check that the token has not been used.
if err = a.UseToken(token, prov); err != nil { if err := a.UseToken(token, prov); err != nil {
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token") return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
} }
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes. // more than a few minutes.
if err = claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: prov.GetName(), Issuer: prov.GetName(),
Time: time.Now().UTC(), Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
@ -262,7 +262,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
} }
if err = p.AuthorizeRevoke(ctx, token); err != nil { if err := p.AuthorizeRevoke(ctx, token); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
} }
return nil return nil

View file

@ -917,7 +917,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if err = cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
return cert, jwk, nil return cert, jwk, nil

View file

@ -25,7 +25,7 @@ func (s multiString) HasEmpties() bool {
return true return true
} }
for _, ss := range s { for _, ss := range s {
if len(ss) == 0 { if ss == "" {
return true return true
} }
} }

View file

@ -272,12 +272,12 @@ func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificat
return errors.Wrap(err, "error revoking certificate") return errors.Wrap(err, "error revoking certificate")
} }
func (c *linkedCaClient) RevokeSSH(ssh *ssh.Certificate, rci *db.RevokedCertificateInfo) error { func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
_, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{ _, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{
Serial: rci.Serial, Serial: rci.Serial,
Certificate: serializeSSHCertificate(ssh), Certificate: serializeSSHCertificate(cert),
Reason: rci.Reason, Reason: rci.Reason,
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode), ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
Passive: true, Passive: true,

View file

@ -22,9 +22,9 @@ type Option func(*Authority) error
// WithConfig replaces the current config with the given one. No validation is // WithConfig replaces the current config with the given one. No validation is
// performed in the given value. // performed in the given value.
func WithConfig(config *config.Config) Option { func WithConfig(cfg *config.Config) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.config = config a.config = cfg
return nil return nil
} }
} }
@ -76,9 +76,9 @@ func WithIssuerPassword(password []byte) Option {
// WithDatabase sets an already initialized authority database to a new // WithDatabase sets an already initialized authority database to a new
// authority. This option is intended to be use on graceful reloads. // authority. This option is intended to be use on graceful reloads.
func WithDatabase(db db.AuthDB) Option { func WithDatabase(d db.AuthDB) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.db = db a.db = d
return nil return nil
} }
} }
@ -225,9 +225,9 @@ func WithX509FederatedBundle(pemCerts []byte) Option {
} }
// WithAdminDB is an option to set the database backing the admin APIs. // WithAdminDB is an option to set the database backing the admin APIs.
func WithAdminDB(db admin.DB) Option { func WithAdminDB(d admin.DB) Option {
return func(a *Authority) error { return func(a *Authority) error {
a.adminDB = db a.adminDB = d
return nil return nil
} }
} }

View file

@ -312,7 +312,7 @@ func (p *AWS) GetType() Type {
} }
// GetEncryptedKey is not available in an AWS provisioner. // GetEncryptedKey is not available in an AWS provisioner.
func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) { func (p *AWS) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
@ -449,13 +449,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// There's no way to trust them other than TOFU. // There's no way to trust them other than TOFU.
var so []SignOption var so []SignOption
if p.DisableCustomSANs { if p.DisableCustomSANs {
dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region)
so = append(so, dnsNamesValidator([]string{dnsName})) so = append(so,
so = append(so, ipAddressesValidator([]net.IP{ dnsNamesValidator([]string{dnsName}),
net.ParseIP(doc.PrivateIP), ipAddressesValidator([]net.IP{
})) net.ParseIP(doc.PrivateIP),
so = append(so, emailAddressesValidator(nil)) }),
so = append(so, urisValidator(nil)) emailAddressesValidator(nil),
urisValidator(nil),
)
// Template options // Template options
data.SetSANs([]string{dnsName, doc.PrivateIP}) data.SetSANs([]string{dnsName, doc.PrivateIP})
@ -669,7 +671,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
if p.DisableCustomSANs { if p.DisableCustomSANs {
if payload.Subject != doc.InstanceID && if payload.Subject != doc.InstanceID &&
payload.Subject != doc.PrivateIP && payload.Subject != doc.PrivateIP &&
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) { payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) {
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)") return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
} }
} }
@ -720,7 +722,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validated principals. // Validated principals.
principals := []string{ principals := []string{
doc.PrivateIP, doc.PrivateIP,
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region), fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region),
} }
// Only enforce known principals if disable custom sans is true. // Only enforce known principals if disable custom sans is true.

View file

@ -663,15 +663,15 @@ func TestAWS_AuthorizeSign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token) switch got, err := tt.aws.AuthorizeSign(ctx, tt.args.token); {
if (err != nil) != tt.wantErr { case (err != nil) != tt.wantErr:
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} else { default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {

View file

@ -152,7 +152,7 @@ func (p *Azure) GetType() Type {
} }
// GetEncryptedKey is not available in an Azure provisioner. // GetEncryptedKey is not available in an Azure provisioner.
func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) { func (p *Azure) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
@ -303,11 +303,13 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
var so []SignOption var so []SignOption
if p.DisableCustomSANs { if p.DisableCustomSANs {
// name will work only inside the virtual network // name will work only inside the virtual network
so = append(so, commonNameValidator(name)) so = append(so,
so = append(so, dnsNamesValidator([]string{name})) commonNameValidator(name),
so = append(so, ipAddressesValidator(nil)) dnsNamesValidator([]string{name}),
so = append(so, emailAddressesValidator(nil)) ipAddressesValidator(nil),
so = append(so, urisValidator(nil)) emailAddressesValidator(nil),
urisValidator(nil),
)
// Enforce SANs in the template. // Enforce SANs in the template.
data.SetSANs([]string{name}) data.SetSANs([]string{name})

View file

@ -446,15 +446,15 @@ func TestAzure_AuthorizeSign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token) switch got, err := tt.azure.AuthorizeSign(ctx, tt.args.token); {
if (err != nil) != tt.wantErr { case (err != nil) != tt.wantErr:
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} else { default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {

View file

@ -229,14 +229,15 @@ func (c *Collection) Remove(id string) error {
var found bool var found bool
for i, elem := range c.sorted { for i, elem := range c.sorted {
if elem.provisioner.GetID() == id { if elem.provisioner.GetID() != id {
// Remove index in sorted list continue
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
found = true
break
} }
// Remove index in sorted list
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
c.sorted = c.sorted[:len(c.sorted)-1] // Truncate slice.
found = true
break
} }
if !found { if !found {
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName()) return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName())

View file

@ -150,7 +150,7 @@ func (p *GCP) GetType() Type {
} }
// GetEncryptedKey is not available in a GCP provisioner. // GetEncryptedKey is not available in a GCP provisioner.
func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) { func (p *GCP) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
@ -244,15 +244,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
if p.DisableCustomSANs { if p.DisableCustomSANs {
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID) dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID) dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
so = append(so, commonNameSliceValidator([]string{ so = append(so,
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2, commonNameSliceValidator([]string{
})) ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
so = append(so, dnsNamesValidator([]string{ }),
dnsName1, dnsName2, dnsNamesValidator([]string{
})) dnsName1, dnsName2,
so = append(so, ipAddressesValidator(nil)) }),
so = append(so, emailAddressesValidator(nil)) ipAddressesValidator(nil),
so = append(so, urisValidator(nil)) emailAddressesValidator(nil),
urisValidator(nil),
)
// Template SANs // Template SANs
data.SetSANs([]string{dnsName1, dnsName2}) data.SetSANs([]string{dnsName1, dnsName2})

View file

@ -535,15 +535,15 @@ func TestGCP_AuthorizeSign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := NewContextWithMethod(context.Background(), SignMethod) ctx := NewContextWithMethod(context.Background(), SignMethod)
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token) switch got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token); {
if (err != nil) != tt.wantErr { case (err != nil) != tt.wantErr:
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
return return
} else if err != nil { case err != nil:
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
} else { default:
assert.Len(t, tt.wantLen, got) assert.Len(t, tt.wantLen, got)
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {

View file

@ -18,7 +18,7 @@ const (
defaultCacheJitter = 1 * time.Hour defaultCacheJitter = 1 * time.Hour
) )
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)") var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)
type keyStore struct { type keyStore struct {
sync.RWMutex sync.RWMutex

View file

@ -29,7 +29,7 @@ func (p *noop) GetType() Type {
return noopType return noopType
} }
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) { func (p *noop) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }

View file

@ -148,7 +148,7 @@ func (o *OIDC) GetType() Type {
} }
// GetEncryptedKey is not available in an OIDC provisioner. // GetEncryptedKey is not available in an OIDC provisioner.
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) { func (o *OIDC) GetEncryptedKey() (kid, key string, ok bool) {
return "", "", false return "", "", false
} }
@ -193,7 +193,7 @@ func (o *OIDC) Init(config Config) (err error) {
} }
// Replace {tenantid} with the configured one // Replace {tenantid} with the configured one
if o.TenantID != "" { if o.TenantID != "" {
o.configuration.Issuer = strings.Replace(o.configuration.Issuer, "{tenantid}", o.TenantID, -1) o.configuration.Issuer = strings.ReplaceAll(o.configuration.Issuer, "{tenantid}", o.TenantID)
} }
// Get JWK key set // Get JWK key set
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI) o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)

View file

@ -321,32 +321,26 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
assert.Nil(t, got) assert.Nil(t, got)
} else { } else if assert.NotNil(t, got) {
if assert.NotNil(t, got) { assert.Len(t, 5, got)
if tt.name == "admin" { for _, o := range got {
assert.Len(t, 5, got) switch v := o.(type) {
} else { case certificateOptionsFunc:
assert.Len(t, 5, got) case *provisionerExtensionOption:
} assert.Equals(t, v.Type, int(TypeOIDC))
for _, o := range got { assert.Equals(t, v.Name, tt.prov.GetName())
switch v := o.(type) { assert.Equals(t, v.CredentialID, tt.prov.ClientID)
case certificateOptionsFunc: assert.Len(t, 0, v.KeyValuePairs)
case *provisionerExtensionOption: case profileDefaultDuration:
assert.Equals(t, v.Type, int(TypeOIDC)) assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
assert.Equals(t, v.Name, tt.prov.GetName()) case defaultPublicKeyValidator:
assert.Equals(t, v.CredentialID, tt.prov.ClientID) case *validityValidator:
assert.Len(t, 0, v.KeyValuePairs) assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
case profileDefaultDuration: assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) case emailOnlyIdentity:
case defaultPublicKeyValidator: assert.Equals(t, string(v), "name@smallstep.com")
case *validityValidator: default:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com")
default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
}
} }
} }
} }

View file

@ -138,7 +138,7 @@ func unsafeParseSigned(s string) (map[string]interface{}, error) {
return nil, err return nil, err
} }
claims := make(map[string]interface{}) claims := make(map[string]interface{})
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil { if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, err return nil, err
} }
return claims, nil return claims, nil

View file

@ -123,7 +123,7 @@ func (a Audiences) WithFragment(fragment string) Audiences {
// generateSignAudience generates a sign audience with the format // generateSignAudience generates a sign audience with the format
// https://<host>/1.0/sign#provisionerID // https://<host>/1.0/sign#provisionerID
func generateSignAudience(caURL string, provisionerID string) (string, error) { func generateSignAudience(caURL, provisionerID string) (string, error) {
u, err := url.Parse(caURL) u, err := url.Parse(caURL)
if err != nil { if err != nil {
return "", errors.Wrapf(err, "error parsing %s", caURL) return "", errors.Wrapf(err, "error parsing %s", caURL)

View file

@ -44,7 +44,7 @@ func TestSSHOptions_Modify(t *testing.T) {
valid func(*ssh.Certificate) valid func(*ssh.Certificate)
err error err error
} }
tests := map[string](func() test){ tests := map[string]func() test{
"fail/unexpected-cert-type": func() test { "fail/unexpected-cert-type": func() test {
return test{ return test{
so: SignSSHOptions{CertType: "foo"}, so: SignSSHOptions{CertType: "foo"},
@ -117,7 +117,7 @@ func TestSSHOptions_Match(t *testing.T) {
cmp SignSSHOptions cmp SignSSHOptions
err error err error
} }
tests := map[string](func() test){ tests := map[string]func() test{
"fail/cert-type": func() test { "fail/cert-type": func() test {
return test{ return test{
so: SignSSHOptions{CertType: "foo"}, so: SignSSHOptions{CertType: "foo"},
@ -208,7 +208,7 @@ func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
expected []string expected []string
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok": func() test { "ok": func() test {
a := []string{"foo", "bar"} a := []string{"foo", "bar"}
return test{ return test{
@ -234,7 +234,7 @@ func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
expected string expected string
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok": func() test { "ok": func() test {
a := "foo" a := "foo"
return test{ return test{
@ -260,7 +260,7 @@ func Test_sshCertTypeModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
expected uint32 expected uint32
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok/user": func() test { "ok/user": func() test {
return test{ return test{
modifier: sshCertTypeModifier("user"), modifier: sshCertTypeModifier("user"),
@ -299,7 +299,7 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
expected uint64 expected uint64
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok": func() test { "ok": func() test {
return test{ return test{
modifier: sshCertValidAfterModifier(15), modifier: sshCertValidAfterModifier(15),
@ -324,7 +324,7 @@ func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
cert *ssh.Certificate cert *ssh.Certificate
valid func(*ssh.Certificate) valid func(*ssh.Certificate)
} }
tests := map[string](func() test){ tests := map[string]func() test{
"ok/changes": func() test { "ok/changes": func() test {
n := time.Now() n := time.Now()
va := NewTimeDuration(n.Add(1 * time.Minute)) va := NewTimeDuration(n.Add(1 * time.Minute))
@ -388,7 +388,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
valid func(*ssh.Certificate) valid func(*ssh.Certificate)
err error err error
} }
tests := map[string](func() test){ tests := map[string]func() test{
"fail/unexpected-cert-type": func() test { "fail/unexpected-cert-type": func() test {
cert := &ssh.Certificate{CertType: 3} cert := &ssh.Certificate{CertType: 3}
return test{ return test{

View file

@ -46,7 +46,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if err = cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
return cert, jwk, nil return cert, jwk, nil
@ -214,10 +214,8 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
assert.HasPrefix(t, err.Error(), tc.err.Error()) assert.HasPrefix(t, err.Error(), tc.err.Error())
} }
} else { } else if assert.Nil(t, tc.err) {
if assert.Nil(t, tc.err) { assert.NotNil(t, claims)
assert.NotNil(t, claims)
}
} }
}) })
} }

View file

@ -732,7 +732,7 @@ func withSSHPOPFile(cert *ssh.Certificate) tokOption {
} }
} }
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions) so := new(jose.SignerOptions)
so.WithType("JWT") so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID) so.WithHeader("kid", jwk.KeyID)
@ -773,7 +773,7 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateOIDCToken(sub, iss, aud string, email string, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) { func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
so := new(jose.SignerOptions) so := new(jose.SignerOptions)
so.WithType("JWT") so.WithType("JWT")
so.WithHeader("kid", jwk.KeyID) so.WithHeader("kid", jwk.KeyID)

View file

@ -108,7 +108,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
// GetSSHBastion returns the bastion configuration, for the given pair user, // GetSSHBastion returns the bastion configuration, for the given pair user,
// hostname. // hostname.
func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*config.Bastion, error) { func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*config.Bastion, error) {
if a.sshBastionFunc != nil { if a.sshBastionFunc != nil {
bs, err := a.sshBastionFunc(ctx, user, hostname) bs, err := a.sshBastionFunc(ctx, user, hostname)
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion") return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
@ -477,7 +477,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
} }
// CheckSSHHost checks the given principal has been registered before. // CheckSSHHost checks the given principal has been registered before.
func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) { func (a *Authority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) {
if a.sshCheckHostFunc != nil { if a.sshCheckHostFunc != nil {
exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates()) exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates())
if err != nil { if err != nil {
@ -531,5 +531,5 @@ func (a *Authority) getAddUserCommand(principal string) string {
} else { } else {
cmd = a.config.SSH.AddUserCommand cmd = a.config.SSH.AddUserCommand
} }
return strings.Replace(cmd, "<principal>", principal, -1) return strings.ReplaceAll(cmd, "<principal>", principal)
} }

View file

@ -55,10 +55,10 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" { if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" {
crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress) crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress)
} }
if len(crt.Subject.SerialNumber) == 0 && def.SerialNumber != "" { if crt.Subject.SerialNumber == "" && def.SerialNumber != "" {
crt.Subject.SerialNumber = def.SerialNumber crt.Subject.SerialNumber = def.SerialNumber
} }
if len(crt.Subject.CommonName) == 0 && def.CommonName != "" { if crt.Subject.CommonName == "" && def.CommonName != "" {
crt.Subject.CommonName = def.CommonName crt.Subject.CommonName = def.CommonName
} }
return nil return nil
@ -387,14 +387,14 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
return errs.Wrap(http.StatusInternalServerError, err, return errs.Wrap(http.StatusInternalServerError, err,
"authority.Revoke; could not get ID for token") "authority.Revoke; could not get ID for token")
} }
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID)) opts = append(opts,
opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID)) errs.WithKeyVal("provisionerID", rci.ProvisionerID),
} else { errs.WithKeyVal("tokenID", rci.TokenID),
)
} else if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil {
// Load the Certificate provisioner if one exists. // Load the Certificate provisioner if one exists.
if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil { rci.ProvisionerID = p.GetID()
rci.ProvisionerID = p.GetID() opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
}
} }
if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod { if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {

View file

@ -426,6 +426,7 @@ ZYtQ9Ot36qc=
{Id: stepOIDProvisioner, Value: []byte("foo")}, {Id: stepOIDProvisioner, Value: []byte("foo")},
{Id: []int{1, 1, 1}, Value: []byte("bar")}})) {Id: []int{1, 1, 1}, Value: []byte("bar")}}))
now := time.Now().UTC() now := time.Now().UTC()
// nolint:gocritic
enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{ enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
NotBefore: now, NotBefore: now,
NotAfter: now.Add(365 * 24 * time.Hour), NotAfter: now.Add(365 * 24 * time.Hour),

View file

@ -345,7 +345,7 @@ func readACMEError(r io.ReadCloser) error {
ae := new(acme.Error) ae := new(acme.Error)
err = json.Unmarshal(b, &ae) err = json.Unmarshal(b, &ae)
// If we successfully marshaled to an ACMEError then return the ACMEError. // If we successfully marshaled to an ACMEError then return the ACMEError.
if err != nil || len(ae.Error()) == 0 { if err != nil || ae.Error() == "" {
fmt.Printf("b = %s\n", b) fmt.Printf("b = %s\n", b)
// Throw up our hands. // Throw up our hands.
return errors.Errorf("%s", b) return errors.Errorf("%s", b)

View file

@ -1247,6 +1247,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
Type: "Certificate", Type: "Certificate",
Bytes: leaf.Raw, Bytes: leaf.Raw,
}) })
// nolint:gocritic
certBytes := append(leafb, leafb...) certBytes := append(leafb, leafb...)
certBytes = append(certBytes, leafb...) certBytes = append(certBytes, leafb...)
ac := &ACMEClient{ ac := &ACMEClient{

View file

@ -70,7 +70,7 @@ func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error)
}, nil }, nil
} }
func (c *AdminClient) generateAdminToken(path string) (string, error) { func (c *AdminClient) generateAdminToken(urlPath string) (string, error) {
// A random jwt id will be used to identify duplicated tokens // A random jwt id will be used to identify duplicated tokens
jwtID, err := randutil.Hex(64) // 256 bits jwtID, err := randutil.Hex(64) // 256 bits
if err != nil { if err != nil {
@ -82,7 +82,7 @@ func (c *AdminClient) generateAdminToken(path string) (string, error) {
token.WithJWTID(jwtID), token.WithJWTID(jwtID),
token.WithKid(c.x5cJWK.KeyID), token.WithKid(c.x5cJWK.KeyID),
token.WithIssuer(c.x5cIssuer), token.WithIssuer(c.x5cIssuer),
token.WithAudience(path), token.WithAudience(urlPath),
token.WithValidity(now, now.Add(token.DefaultValidity)), token.WithValidity(now, now.Add(token.DefaultValidity)),
token.WithX5CCerts(c.x5cCertStrs), token.WithX5CCerts(c.x5cCertStrs),
} }
@ -348,14 +348,15 @@ func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provi
return nil, err return nil, err
} }
var u *url.URL var u *url.URL
if len(o.id) > 0 { switch {
case len(o.id) > 0:
u = c.endpoint.ResolveReference(&url.URL{ u = c.endpoint.ResolveReference(&url.URL{
Path: "/admin/provisioners/id", Path: "/admin/provisioners/id",
RawQuery: o.rawQuery(), RawQuery: o.rawQuery(),
}) })
} else if len(o.name) > 0 { case len(o.name) > 0:
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)}) u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
} else { default:
return nil, errors.New("must set either name or id in method options") return nil, errors.New("must set either name or id in method options")
} }
tok, err := c.generateAdminToken(u.Path) tok, err := c.generateAdminToken(u.Path)
@ -456,14 +457,15 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error {
return err return err
} }
if len(o.id) > 0 { switch {
case len(o.id) > 0:
u = c.endpoint.ResolveReference(&url.URL{ u = c.endpoint.ResolveReference(&url.URL{
Path: path.Join(adminURLPrefix, "provisioners/id"), Path: path.Join(adminURLPrefix, "provisioners/id"),
RawQuery: o.rawQuery(), RawQuery: o.rawQuery(),
}) })
} else if len(o.name) > 0 { case len(o.name) > 0:
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)}) u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
} else { default:
return errors.New("must set either name or id in method options") return errors.New("must set either name or id in method options")
} }
tok, err := c.generateAdminToken(u.Path) tok, err := c.generateAdminToken(u.Path)

View file

@ -30,7 +30,7 @@ func Bootstrap(token string) (*Client, error) {
// Validate bootstrap token // Validate bootstrap token
switch { switch {
case len(claims.SHA) == 0: case claims.SHA == "":
return nil, errors.New("invalid bootstrap token: sha claim is not present") return nil, errors.New("invalid bootstrap token: sha claim is not present")
case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"): case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"):
return nil, errors.New("invalid bootstrap token: aud claim is not a url") return nil, errors.New("invalid bootstrap token: aud claim is not a url")

View file

@ -88,9 +88,9 @@ func WithIssuerPassword(password []byte) Option {
} }
// WithDatabase sets the given authority database to the CA options. // WithDatabase sets the given authority database to the CA options.
func WithDatabase(db db.AuthDB) Option { func WithDatabase(d db.AuthDB) Option {
return func(o *options) { return func(o *options) {
o.database = db o.database = d
} }
} }
@ -113,17 +113,17 @@ type CA struct {
} }
// New creates and initializes the CA with the given configuration and options. // New creates and initializes the CA with the given configuration and options.
func New(config *config.Config, opts ...Option) (*CA, error) { func New(cfg *config.Config, opts ...Option) (*CA, error) {
ca := &CA{ ca := &CA{
config: config, config: cfg,
opts: new(options), opts: new(options),
} }
ca.opts.apply(opts) ca.opts.apply(opts)
return ca.Init(config) return ca.Init(cfg)
} }
// Init initializes the CA with the given configuration. // Init initializes the CA with the given configuration.
func (ca *CA) Init(config *config.Config) (*CA, error) { func (ca *CA) Init(cfg *config.Config) (*CA, error) {
// Set password, it's ok to set nil password, the ca will prompt for them if // Set password, it's ok to set nil password, the ca will prompt for them if
// they are required. // they are required.
opts := []authority.Option{ opts := []authority.Option{
@ -140,7 +140,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
opts = append(opts, authority.WithDatabase(ca.opts.database)) opts = append(opts, authority.WithDatabase(ca.opts.database))
} }
auth, err := authority.New(config, opts...) auth, err := authority.New(cfg, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -166,8 +166,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
}) })
//Add ACME api endpoints in /acme and /1.0/acme //Add ACME api endpoints in /acme and /1.0/acme
dns := config.DNSNames[0] dns := cfg.DNSNames[0]
u, err := url.Parse("https://" + config.Address) u, err := url.Parse("https://" + cfg.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -179,7 +179,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
// ACME Router // ACME Router
prefix := "acme" prefix := "acme"
var acmeDB acme.DB var acmeDB acme.DB
if config.DB == nil { if cfg.DB == nil {
acmeDB = nil acmeDB = nil
} else { } else {
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
@ -188,7 +188,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
} }
} }
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
Backdate: *config.AuthorityConfig.Backdate, Backdate: *cfg.AuthorityConfig.Backdate,
DB: acmeDB, DB: acmeDB,
DNS: dns, DNS: dns,
Prefix: prefix, Prefix: prefix,
@ -204,7 +204,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
}) })
// Admin API Router // Admin API Router
if config.AuthorityConfig.EnableAdmin { if cfg.AuthorityConfig.EnableAdmin {
adminDB := auth.GetAdminDatabase() adminDB := auth.GetAdminDatabase()
if adminDB != nil { if adminDB != nil {
adminHandler := adminAPI.NewHandler(auth) adminHandler := adminAPI.NewHandler(auth)
@ -248,8 +248,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
//dumpRoutes(mux) //dumpRoutes(mux)
// Add monitoring if configured // Add monitoring if configured
if len(config.Monitoring) > 0 { if len(cfg.Monitoring) > 0 {
m, err := monitoring.New(config.Monitoring) m, err := monitoring.New(cfg.Monitoring)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -258,8 +258,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
} }
// Add logger if configured // Add logger if configured
if len(config.Logger) > 0 { if len(cfg.Logger) > 0 {
logger, err := logging.New("ca", config.Logger) logger, err := logging.New("ca", cfg.Logger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -267,16 +267,16 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
insecureHandler = logger.Middleware(insecureHandler) insecureHandler = logger.Middleware(insecureHandler)
} }
ca.srv = server.New(config.Address, handler, tlsConfig) ca.srv = server.New(cfg.Address, handler, tlsConfig)
// only start the insecure server if the insecure address is configured // only start the insecure server if the insecure address is configured
// and, currently, also only when it should serve SCEP endpoints. // and, currently, also only when it should serve SCEP endpoints.
if ca.shouldServeSCEPEndpoints() && config.InsecureAddress != "" { if ca.shouldServeSCEPEndpoints() && cfg.InsecureAddress != "" {
// TODO: instead opt for having a single server.Server but two // TODO: instead opt for having a single server.Server but two
// http.Servers handling the HTTP and HTTPS handler? The latter // http.Servers handling the HTTP and HTTPS handler? The latter
// will probably introduce more complexity in terms of graceful // will probably introduce more complexity in terms of graceful
// reload. // reload.
ca.insecureSrv = server.New(config.InsecureAddress, insecureHandler, nil) ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil)
} }
return ca, nil return ca, nil
@ -285,24 +285,24 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
// Run starts the CA calling to the server ListenAndServe method. // Run starts the CA calling to the server ListenAndServe method.
func (ca *CA) Run() error { func (ca *CA) Run() error {
var wg sync.WaitGroup var wg sync.WaitGroup
errors := make(chan error, 1) errs := make(chan error, 1)
if ca.insecureSrv != nil { if ca.insecureSrv != nil {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errors <- ca.insecureSrv.ListenAndServe() errs <- ca.insecureSrv.ListenAndServe()
}() }()
} }
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
errors <- ca.srv.ListenAndServe() errs <- ca.srv.ListenAndServe()
}() }()
// wait till error occurs; ensures the servers keep listening // wait till error occurs; ensures the servers keep listening
err := <-errors err := <-errs
wg.Wait() wg.Wait()
@ -331,7 +331,7 @@ func (ca *CA) Stop() error {
// Reload reloads the configuration of the CA and calls to the server Reload // Reload reloads the configuration of the CA and calls to the server Reload
// method. // method.
func (ca *CA) Reload() error { func (ca *CA) Reload() error {
config, err := config.LoadConfiguration(ca.opts.configFile) cfg, err := config.LoadConfiguration(ca.opts.configFile)
if err != nil { if err != nil {
return errors.Wrap(err, "error reloading ca configuration") return errors.Wrap(err, "error reloading ca configuration")
} }
@ -343,12 +343,12 @@ func (ca *CA) Reload() error {
} }
// Do not allow reload if the database configuration has changed. // Do not allow reload if the database configuration has changed.
if !reflect.DeepEqual(ca.config.DB, config.DB) { if !reflect.DeepEqual(ca.config.DB, cfg.DB) {
logContinue("Reload failed because the database configuration has changed.") logContinue("Reload failed because the database configuration has changed.")
return errors.New("error reloading ca: database configuration cannot change") return errors.New("error reloading ca: database configuration cannot change")
} }
newCA, err := New(config, newCA, err := New(cfg,
WithPassword(ca.opts.password), WithPassword(ca.opts.password),
WithSSHHostPassword(ca.opts.sshHostPassword), WithSSHHostPassword(ca.opts.sshHostPassword),
WithSSHUserPassword(ca.opts.sshUserPassword), WithSSHUserPassword(ca.opts.sshUserPassword),

View file

@ -322,7 +322,7 @@ ZEp7knvU2psWRw==
assert.Equals(t, intermediate, realIntermediate) assert.Equals(t, intermediate, realIntermediate)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)
@ -375,7 +375,7 @@ func TestCAProvisioners(t *testing.T) {
assert.Equals(t, a, b) assert.Equals(t, a, b)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)
@ -436,7 +436,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
assert.Equals(t, ek.Key, tc.expectedKey) assert.Equals(t, ek.Key, tc.expectedKey)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)
@ -497,7 +497,7 @@ func TestCARoot(t *testing.T) {
assert.Equals(t, root.RootPEM.Certificate, rootCrt) assert.Equals(t, root.RootPEM.Certificate, rootCrt)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)
@ -665,7 +665,7 @@ func TestCARenew(t *testing.T) {
assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions) assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
} else { } else {
err := readError(body) err := readError(body)
if len(tc.errMsg) == 0 { if tc.errMsg == "" {
assert.FatalError(t, errors.New("must validate response error")) assert.FatalError(t, errors.New("must validate response error"))
} }
assert.HasPrefix(t, err.Error(), tc.errMsg) assert.HasPrefix(t, err.Error(), tc.errMsg)

View file

@ -74,17 +74,17 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) {
c.Client.Transport = tr c.Client.Transport = tr
} }
func (c *uaClient) Get(url string) (*http.Response, error) { func (c *uaClient) Get(u string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", u, nil)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "new request GET %s failed", url) return nil, errors.Wrapf(err, "new request GET %s failed", u)
} }
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
return c.Client.Do(req) return c.Client.Do(req)
} }
func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) { func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", url, body) req, err := http.NewRequest("POST", u, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -305,7 +305,7 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
err error err error
opts []jose.Option opts []jose.Option
) )
if len(passwordFile) != 0 { if passwordFile != "" {
opts = append(opts, jose.WithPasswordFile(passwordFile)) opts = append(opts, jose.WithPasswordFile(passwordFile))
} }
blk, err := pemutil.Serialize(key) blk, err := pemutil.Serialize(key)
@ -326,14 +326,14 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
for _, e := range o.x5cCert.Extensions { for _, e := range o.x5cCert.Extensions {
if e.Id.Equal(stepOIDProvisioner) { if e.Id.Equal(stepOIDProvisioner) {
var provisioner stepProvisionerASN1 var prov stepProvisionerASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { if _, err := asn1.Unmarshal(e.Value, &prov); err != nil {
return errors.Wrap(err, "error unmarshaling provisioner OID from certificate") return errors.Wrap(err, "error unmarshaling provisioner OID from certificate")
} }
o.x5cIssuer = string(provisioner.Name) o.x5cIssuer = string(prov.Name)
} }
} }
if len(o.x5cIssuer) == 0 { if o.x5cIssuer == "" {
return errors.New("provisioner extension not found in certificate") return errors.New("provisioner extension not found in certificate")
} }
@ -631,7 +631,7 @@ retry:
// do not match. // do not match.
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
var retried bool var retried bool
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1)) sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", ""))
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum}) u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
retry: retry:
resp, err := newInsecureClient().Get(u.String()) resp, err := newInsecureClient().Get(u.String())
@ -651,7 +651,7 @@ retry:
} }
// verify the sha256 // verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw) sum := sha256.Sum256(root.RootPEM.Raw)
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) { if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) {
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
} }
return &root, nil return &root, nil
@ -1066,16 +1066,16 @@ retry:
} }
return nil, readError(resp.Body) return nil, readError(resp.Body)
} }
var config api.SSHConfigResponse var cfg api.SSHConfigResponse
if err := readJSON(resp.Body, &config); err != nil { if err := readJSON(resp.Body, &cfg); err != nil {
return nil, errors.Wrapf(err, "error reading %s", u) return nil, errors.Wrapf(err, "error reading %s", u)
} }
return &config, nil return &cfg, nil
} }
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the // SSHCheckHost performs the POST /ssh/check-host request to the CA with the
// given principal. // given principal.
func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrincipalResponse, error) { func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) {
var retried bool var retried bool
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{ body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
Type: provisioner.SSHHostCert, Type: provisioner.SSHHostCert,

View file

@ -135,7 +135,7 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
return csr return csr
} }
func equalJSON(t *testing.T, a interface{}, b interface{}) bool { func equalJSON(t *testing.T, a, b interface{}) bool {
if reflect.DeepEqual(a, b) { if reflect.DeepEqual(a, b) {
return true return true
} }

View file

@ -187,11 +187,12 @@ func TestLoadClient(t *testing.T) {
} else { } else {
gotTransport := got.Client.Transport.(*http.Transport) gotTransport := got.Client.Transport.(*http.Transport)
wantTransport := tt.want.Client.Transport.(*http.Transport) wantTransport := tt.want.Client.Transport.(*http.Transport)
if gotTransport.TLSClientConfig.GetClientCertificate == nil { switch {
case gotTransport.TLSClientConfig.GetClientCertificate == nil:
t.Error("LoadClient() transport does not define GetClientCertificate") t.Error("LoadClient() transport does not define GetClientCertificate")
} else if !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()) { case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()):
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
} else { default:
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
if err != nil { if err != nil {
t.Errorf("LoadClient() GetClientCertificate error = %v", err) t.Errorf("LoadClient() GetClientCertificate error = %v", err)

View file

@ -105,7 +105,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
tr := getDefaultTransport(tlsConfig) tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew // Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
@ -154,7 +154,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
// Update renew function with transport // Update renew function with transport
tr := getDefaultTransport(tlsConfig) tr := getDefaultTransport(tlsConfig)
// Use mutable tls.Config on renew // Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx) // tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
@ -195,7 +195,7 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
} }
// buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport. // buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
// nolint:unused // nolint:unused,gocritic
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) { func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(ctx context.Context, network, addr string) (net.Conn, error) { return func(ctx context.Context, network, addr string) (net.Conn, error) {
d := getDefaultDialer() d := getDefaultDialer()
@ -253,6 +253,8 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
return nil, err return nil, err
} }
// nolint:gocritic
// using a new variable for clarity
chain := append(certPEM, caPEM...) chain := append(certPEM, caPEM...)
cert, err := tls.X509KeyPair(chain, keyPEM) cert, err := tls.X509KeyPair(chain, keyPEM)
if err != nil { if err != nil {

View file

@ -29,9 +29,7 @@ func init() {
}) })
} }
var now = func() time.Time { var now = time.Now
return time.Now()
}
// The actual regular expression that matches a certificate authority is: // The actual regular expression that matches a certificate authority is:
// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$ // ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$

View file

@ -12,7 +12,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"os" "os"
"reflect" "reflect"
@ -103,7 +102,7 @@ MHcCAQEEIN51Rgg6YcQVLeCRzumdw4pjM3VWqFIdCbnsV3Up1e/goAoGCCqGSM49
AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk
Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg== Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg==
-----END EC PRIVATE KEY-----` -----END EC PRIVATE KEY-----`
// nolint:unused,deadcode // nolint:unused,deadcode,gocritic
testIntermediateKey = `-----BEGIN EC PRIVATE KEY----- testIntermediateKey = `-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49 MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49
AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6 AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6
@ -190,7 +189,7 @@ func (b *badSigner) Public() crypto.PublicKey {
return b.pub return b.pub
} }
func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { func (b *badSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
return nil, fmt.Errorf("💥") return nil, fmt.Errorf("💥")
} }
@ -730,7 +729,7 @@ func TestCloudCAS_RevokeCertificate(t *testing.T) {
func Test_createCertificateID(t *testing.T) { func Test_createCertificateID(t *testing.T) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
setTeeReader(t, buf) setTeeReader(t, buf)
uuid, err := uuid.NewRandomFromReader(rand.Reader) id, err := uuid.NewRandomFromReader(rand.Reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -741,7 +740,7 @@ func Test_createCertificateID(t *testing.T) {
want string want string
wantErr bool wantErr bool
}{ }{
{"ok", uuid.String(), false}, {"ok", id.String(), false},
{"fail", "", true}, {"fail", "", true},
} }
for _, tt := range tests { for _, tt := range tests {
@ -858,7 +857,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) {
return lis.Dial() return lis.Dial()
})) }))
if err != nil { if err != nil {
log.Fatal(err) t.Fatal(err)
} }
client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn)) client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn))

View file

@ -19,9 +19,7 @@ func init() {
}) })
} }
var now = func() time.Time { var now = time.Now
return time.Now()
}
// SoftCAS implements a Certificate Authority Service using Golang or KMS // SoftCAS implements a Certificate Authority Service using Golang or KMS
// crypto. This is the default CAS used in step-ca. // crypto. This is the default CAS used in step-ca.

View file

@ -133,7 +133,7 @@ func (b *badSigner) Public() crypto.PublicKey {
return testSigner.Public() return testSigner.Public()
} }
func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { func (b *badSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) {
return nil, fmt.Errorf("💥") return nil, fmt.Errorf("💥")
} }

View file

@ -90,9 +90,9 @@ func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R
return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"} return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"}
} }
// RevokeCertificate revokes a certificate.
func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) { func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
switch { if req.SerialNumber == "" && req.Certificate == nil {
case req.SerialNumber == "" && req.Certificate == nil:
return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required") return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required")
} }

View file

@ -19,9 +19,7 @@ const defaultValidity = 5 * time.Minute
// timeNow returns the current time. // timeNow returns the current time.
// This method is used for unit testing purposes. // This method is used for unit testing purposes.
var timeNow = func() time.Time { var timeNow = time.Now
return time.Now()
}
type x5cIssuer struct { type x5cIssuer struct {
caURL *url.URL caURL *url.URL

View file

@ -22,7 +22,7 @@ func (b noneSigner) Public() crypto.PublicKey {
return []byte(b) return []byte(b)
} }
func (b noneSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { func (b noneSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
return digest, nil return digest, nil
} }

View file

@ -24,10 +24,10 @@ import (
func main() { func main() {
var credentialsFile, region string var credentialsFile, region string
var ssh bool var enableSSH bool
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.") flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.")
flag.StringVar(&region, "region", "", "AWS KMS region name.") flag.StringVar(&region, "region", "", "AWS KMS region name.")
flag.BoolVar(&ssh, "ssh", false, "Create SSH keys.") flag.BoolVar(&enableSSH, "ssh", false, "Create SSH keys.")
flag.Usage = usage flag.Usage = usage
flag.Parse() flag.Parse()
@ -47,7 +47,7 @@ func main() {
fatal(err) fatal(err)
} }
if ssh { if enableSSH {
ui.Println() ui.Println()
if err := createSSH(c); err != nil { if err := createSSH(c); err != nil {
fatal(err) fatal(err)
@ -120,7 +120,7 @@ func createX509(c *awskms.KMS) error {
return err return err
} }
if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {
@ -163,7 +163,7 @@ func createX509(c *awskms.KMS) error {
return err return err
} }
if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {
@ -193,7 +193,7 @@ func createSSH(c *awskms.KMS) error {
return err return err
} }
if err = fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { if err := fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
return err return err
} }
@ -214,7 +214,7 @@ func createSSH(c *awskms.KMS) error {
return err return err
} }
if err = fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { if err := fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
return err return err
} }

View file

@ -191,8 +191,8 @@ var placeholderString = regexp.MustCompile(`<.*?>`)
func stringifyFlag(f cli.Flag) string { func stringifyFlag(f cli.Flag) string {
fv := flagValue(f) fv := flagValue(f)
usage := fv.FieldByName("Usage").String() usg := fv.FieldByName("Usage").String()
placeholder := placeholderString.FindString(usage) placeholder := placeholderString.FindString(usg)
if placeholder == "" { if placeholder == "" {
switch f.(type) { switch f.(type) {
case cli.BoolFlag, cli.BoolTFlag: case cli.BoolFlag, cli.BoolTFlag:
@ -200,5 +200,5 @@ func stringifyFlag(f cli.Flag) string {
placeholder = "<value>" placeholder = "<value>"
} }
} }
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usg
} }

View file

@ -27,13 +27,13 @@ func main() {
var credentialsFile string var credentialsFile string
var project, location, ring string var project, location, ring string
var protectionLevelName string var protectionLevelName string
var ssh bool var enableSSH bool
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the Google's Cloud KMS credentials.") flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the Google's Cloud KMS credentials.")
flag.StringVar(&project, "project", "", "Google Cloud Project ID.") flag.StringVar(&project, "project", "", "Google Cloud Project ID.")
flag.StringVar(&location, "location", "global", "Cloud KMS location name.") flag.StringVar(&location, "location", "global", "Cloud KMS location name.")
flag.StringVar(&ring, "ring", "pki", "Cloud KMS ring name.") flag.StringVar(&ring, "ring", "pki", "Cloud KMS ring name.")
flag.StringVar(&protectionLevelName, "protection-level", "SOFTWARE", "Protection level to use, SOFTWARE or HSM.") flag.StringVar(&protectionLevelName, "protection-level", "SOFTWARE", "Protection level to use, SOFTWARE or HSM.")
flag.BoolVar(&ssh, "ssh", false, "Create SSH keys.") flag.BoolVar(&enableSSH, "ssh", false, "Create SSH keys.")
flag.Usage = usage flag.Usage = usage
flag.Parse() flag.Parse()
@ -77,7 +77,7 @@ func main() {
fatal(err) fatal(err)
} }
if ssh { if enableSSH {
ui.Println() ui.Println()
if err := createSSH(c, project, location, ring, protectionLevel); err != nil { if err := createSSH(c, project, location, ring, protectionLevel); err != nil {
fatal(err) fatal(err)
@ -153,7 +153,7 @@ func createPKI(c *cloudkms.CloudKMS, project, location, keyRing string, protecti
return err return err
} }
if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {
@ -197,7 +197,7 @@ func createPKI(c *cloudkms.CloudKMS, project, location, keyRing string, protecti
return err return err
} }
if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {
@ -230,7 +230,7 @@ func createSSH(c *cloudkms.CloudKMS, project, location, keyRing string, protecti
return err return err
} }
if err = fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { if err := fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
return err return err
} }
@ -252,7 +252,7 @@ func createSSH(c *cloudkms.CloudKMS, project, location, keyRing string, protecti
return err return err
} }
if err = fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil { if err := fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
return err return err
} }

View file

@ -329,7 +329,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts { if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts {
if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{ if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.RootObject, Name: c.RootObject,
Certificate: root, Certificate: root,
}); err != nil { }); err != nil {
@ -337,7 +337,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
} }
if err = fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {
@ -406,7 +406,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts { if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts {
if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{ if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.CrtObject, Name: c.CrtObject,
Certificate: intermediate, Certificate: intermediate,
}); err != nil { }); err != nil {
@ -414,7 +414,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
} }
if err = fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {

View file

@ -228,7 +228,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
if cm, ok := k.(kms.CertificateManager); ok { if cm, ok := k.(kms.CertificateManager); ok {
if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{ if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.RootSlot, Name: c.RootSlot,
Certificate: root, Certificate: root,
}); err != nil { }); err != nil {
@ -236,7 +236,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
} }
if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {
@ -305,7 +305,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
if cm, ok := k.(kms.CertificateManager); ok { if cm, ok := k.(kms.CertificateManager); ok {
if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{ if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.CrtSlot, Name: c.CrtSlot,
Certificate: intermediate, Certificate: intermediate,
}); err != nil { }); err != nil {
@ -313,7 +313,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
} }
if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: b, Bytes: b,
}), 0600); err != nil { }), 0600); err != nil {

View file

@ -79,13 +79,13 @@ func appAction(ctx *cli.Context) error {
} }
configFile := ctx.Args().Get(0) configFile := ctx.Args().Get(0)
config, err := config.LoadConfiguration(configFile) cfg, err := config.LoadConfiguration(configFile)
if err != nil { if err != nil {
fatal(err) fatal(err)
} }
if config.AuthorityConfig != nil { if cfg.AuthorityConfig != nil {
if token == "" && strings.EqualFold(config.AuthorityConfig.DeploymentType, pki.LinkedDeployment.String()) { if token == "" && strings.EqualFold(cfg.AuthorityConfig.DeploymentType, pki.LinkedDeployment.String()) {
return errors.New(`'step-ca' requires the '--token' flag for linked deploy type. return errors.New(`'step-ca' requires the '--token' flag for linked deploy type.
To get a linked authority token: To get a linked authority token:
@ -136,7 +136,7 @@ To get a linked authority token:
} }
} }
srv, err := ca.New(config, srv, err := ca.New(cfg,
ca.WithConfigFile(configFile), ca.WithConfigFile(configFile),
ca.WithPassword(password), ca.WithPassword(password),
ca.WithSSHHostPassword(sshHostPassword), ca.WithSSHHostPassword(sshHostPassword),

View file

@ -63,11 +63,11 @@ func exportAction(ctx *cli.Context) error {
passwordFile := ctx.String("password-file") passwordFile := ctx.String("password-file")
issuerPasswordFile := ctx.String("issuer-password-file") issuerPasswordFile := ctx.String("issuer-password-file")
config, err := config.LoadConfiguration(configFile) cfg, err := config.LoadConfiguration(configFile)
if err != nil { if err != nil {
return err return err
} }
if err := config.Validate(); err != nil { if err := cfg.Validate(); err != nil {
return err return err
} }
@ -76,19 +76,19 @@ func exportAction(ctx *cli.Context) error {
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", passwordFile) return errors.Wrapf(err, "error reading %s", passwordFile)
} }
config.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
} }
if issuerPasswordFile != "" { if issuerPasswordFile != "" {
b, err := ioutil.ReadFile(issuerPasswordFile) b, err := ioutil.ReadFile(issuerPasswordFile)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", issuerPasswordFile) return errors.Wrapf(err, "error reading %s", issuerPasswordFile)
} }
if config.AuthorityConfig.CertificateIssuer != nil { if cfg.AuthorityConfig.CertificateIssuer != nil {
config.AuthorityConfig.CertificateIssuer.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) cfg.AuthorityConfig.CertificateIssuer.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
} }
} }
auth, err := authority.New(config) auth, err := authority.New(cfg)
if err != nil { if err != nil {
return err return err
} }

View file

@ -103,8 +103,8 @@ func onboardAction(ctx *cli.Context) error {
return errors.Wrap(msg, "error receiving onboarding guide") return errors.Wrap(msg, "error receiving onboarding guide")
} }
var config onboardingConfiguration var cfg onboardingConfiguration
if err := readJSON(res.Body, &config); err != nil { if err := readJSON(res.Body, &cfg); err != nil {
return errors.Wrap(err, "error unmarshaling response") return errors.Wrap(err, "error unmarshaling response")
} }
@ -112,16 +112,16 @@ func onboardAction(ctx *cli.Context) error {
if err != nil { if err != nil {
return err return err
} }
config.password = []byte(password) cfg.password = []byte(password)
ui.Println("Initializing step-ca with the following configuration:") ui.Println("Initializing step-ca with the following configuration:")
ui.PrintSelected("Name", config.Name) ui.PrintSelected("Name", cfg.Name)
ui.PrintSelected("DNS", config.DNS) ui.PrintSelected("DNS", cfg.DNS)
ui.PrintSelected("Address", config.Address) ui.PrintSelected("Address", cfg.Address)
ui.PrintSelected("Password", password) ui.PrintSelected("Password", password)
ui.Println() ui.Println()
caConfig, fp, err := onboardPKI(config) caConfig, fp, err := onboardPKI(cfg)
if err != nil { if err != nil {
return err return err
} }
@ -149,23 +149,23 @@ func onboardAction(ctx *cli.Context) error {
ui.Println("Initialized!") ui.Println("Initialized!")
ui.Println("Step CA is starting. Please return to the onboarding guide in your browser to continue.") ui.Println("Step CA is starting. Please return to the onboarding guide in your browser to continue.")
srv, err := ca.New(caConfig, ca.WithPassword(config.password)) srv, err := ca.New(caConfig, ca.WithPassword(cfg.password))
if err != nil { if err != nil {
fatal(err) fatal(err)
} }
go ca.StopReloaderHandler(srv) go ca.StopReloaderHandler(srv)
if err = srv.Run(); err != nil && err != http.ErrServerClosed { if err := srv.Run(); err != nil && err != http.ErrServerClosed {
fatal(err) fatal(err)
} }
return nil return nil
} }
func onboardPKI(config onboardingConfiguration) (*config.Config, string, error) { func onboardPKI(cfg onboardingConfiguration) (*config.Config, string, error) {
var opts = []pki.Option{ var opts = []pki.Option{
pki.WithAddress(config.Address), pki.WithAddress(cfg.Address),
pki.WithDNSNames([]string{config.DNS}), pki.WithDNSNames([]string{cfg.DNS}),
pki.WithProvisioner("admin"), pki.WithProvisioner("admin"),
} }
@ -179,25 +179,25 @@ func onboardPKI(config onboardingConfiguration) (*config.Config, string, error)
// Generate pki // Generate pki
ui.Println("Generating root certificate...") ui.Println("Generating root certificate...")
root, err := p.GenerateRootCertificate(config.Name, config.Name, config.Name, config.password) root, err := p.GenerateRootCertificate(cfg.Name, cfg.Name, cfg.Name, cfg.password)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
ui.Println("Generating intermediate certificate...") ui.Println("Generating intermediate certificate...")
err = p.GenerateIntermediateCertificate(config.Name, config.Name, config.Name, root, config.password) err = p.GenerateIntermediateCertificate(cfg.Name, cfg.Name, cfg.Name, root, cfg.password)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
// Write files to disk // Write files to disk
if err = p.WriteFiles(); err != nil { if err := p.WriteFiles(); err != nil {
return nil, "", err return nil, "", err
} }
// Generate provisioner // Generate provisioner
ui.Println("Generating admin provisioner...") ui.Println("Generating admin provisioner...")
if err = p.GenerateKeyPairs(config.password); err != nil { if err := p.GenerateKeyPairs(cfg.password); err != nil {
return nil, "", err return nil, "", err
} }
@ -211,7 +211,7 @@ func onboardPKI(config onboardingConfiguration) (*config.Config, string, error)
if err != nil { if err != nil {
return nil, "", errors.Wrapf(err, "error marshaling %s", p.GetCAConfigPath()) return nil, "", errors.Wrapf(err, "error marshaling %s", p.GetCAConfigPath())
} }
if err = fileutil.WriteFile(p.GetCAConfigPath(), b, 0666); err != nil { if err := fileutil.WriteFile(p.GetCAConfigPath(), b, 0666); err != nil {
return nil, "", errs.FileError(err, p.GetCAConfigPath()) return nil, "", errs.FileError(err, p.GetCAConfigPath())
} }

View file

@ -144,15 +144,15 @@ func TestUseToken(t *testing.T) {
} }
for name, tc := range tests { for name, tc := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
ok, err := tc.db.UseToken(tc.id, tc.tok) switch ok, err := tc.db.UseToken(tc.id, tc.tok); {
if err != nil { case err != nil:
if assert.NotNil(t, tc.want.err) { if assert.NotNil(t, tc.want.err) {
assert.HasPrefix(t, err.Error(), tc.want.err.Error()) assert.HasPrefix(t, err.Error(), tc.want.err.Error())
} }
assert.False(t, ok) assert.False(t, ok)
} else if ok { case ok:
assert.True(t, tc.want.ok) assert.True(t, tc.want.ok)
} else { default:
assert.False(t, tc.want.ok) assert.False(t, tc.want.ok)
} }
}) })

View file

@ -378,6 +378,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
// nolint:gocritic
switch s := got.(type) { switch s := got.(type) {
case *WrappedSSHSigner: case *WrappedSSHSigner:
gotPkS := s.Sshsigner.PublicKey().(*agent.Key).String() + "\n" gotPkS := s.Sshsigner.PublicKey().(*agent.Key).String() + "\n"
@ -562,6 +563,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
// nolint:gocritic
switch tt.want.(type) { switch tt.want.(type) {
case ssh.PublicKey: case ssh.PublicKey:
// If we want a ssh.PublicKey, protote got to a // If we want a ssh.PublicKey, protote got to a

View file

@ -128,7 +128,7 @@ func GetTemplatesPath() string {
// GetProvisioners returns the map of provisioners on the given CA. // GetProvisioners returns the map of provisioners on the given CA.
func GetProvisioners(caURL, rootFile string) (provisioner.List, error) { func GetProvisioners(caURL, rootFile string) (provisioner.List, error) {
if len(rootFile) == 0 { if rootFile == "" {
rootFile = GetRootCAPath() rootFile = GetRootCAPath()
} }
client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile)) client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile))
@ -153,7 +153,7 @@ func GetProvisioners(caURL, rootFile string) (provisioner.List, error) {
// GetProvisionerKey returns the encrypted provisioner key with the for the // GetProvisionerKey returns the encrypted provisioner key with the for the
// given kid. // given kid.
func GetProvisionerKey(caURL, rootFile, kid string) (string, error) { func GetProvisionerKey(caURL, rootFile, kid string) (string, error) {
if len(rootFile) == 0 { if rootFile == "" {
rootFile = GetRootCAPath() rootFile = GetRootCAPath()
} }
client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile)) client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile))
@ -315,17 +315,17 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
// Use /home/step as the step path in helm configurations. // Use /home/step as the step path in helm configurations.
// Use the current step path when creating pki in files. // Use the current step path when creating pki in files.
var public, private, config string var public, private, cfg string
if p.options.isHelm { if p.options.isHelm {
public = "/home/step/certs" public = "/home/step/certs"
private = "/home/step/secrets" private = "/home/step/secrets"
config = "/home/step/config" cfg = "/home/step/config"
} else { } else {
public = GetPublicPath() public = GetPublicPath()
private = GetSecretsPath() private = GetSecretsPath()
config = GetConfigPath() cfg = GetConfigPath()
// Create directories // Create directories
dirs := []string{public, private, config, GetTemplatesPath()} dirs := []string{public, private, cfg, GetTemplatesPath()}
for _, name := range dirs { for _, name := range dirs {
if _, err := os.Stat(name); os.IsNotExist(err) { if _, err := os.Stat(name); os.IsNotExist(err) {
if err = os.MkdirAll(name, 0700); err != nil { if err = os.MkdirAll(name, 0700); err != nil {
@ -380,10 +380,10 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
if p.Ssh.UserKey, err = getPath(private, "ssh_user_ca_key"); err != nil { if p.Ssh.UserKey, err = getPath(private, "ssh_user_ca_key"); err != nil {
return nil, err return nil, err
} }
if p.defaults, err = getPath(config, "defaults.json"); err != nil { if p.defaults, err = getPath(cfg, "defaults.json"); err != nil {
return nil, err return nil, err
} }
if p.config, err = getPath(config, "ca.json"); err != nil { if p.config, err = getPath(cfg, "ca.json"); err != nil {
return nil, err return nil, err
} }
p.Defaults.CaConfig = p.config p.Defaults.CaConfig = p.config
@ -620,16 +620,17 @@ func (p *PKI) askFeedback() {
func (p *PKI) tellPKI() { func (p *PKI) tellPKI() {
ui.Println() ui.Println()
if p.casOptions.Is(apiv1.SoftCAS) { switch {
case p.casOptions.Is(apiv1.SoftCAS):
ui.PrintSelected("Root certificate", p.Root[0]) ui.PrintSelected("Root certificate", p.Root[0])
ui.PrintSelected("Root private key", p.RootKey[0]) ui.PrintSelected("Root private key", p.RootKey[0])
ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint) ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint)
ui.PrintSelected("Intermediate certificate", p.Intermediate) ui.PrintSelected("Intermediate certificate", p.Intermediate)
ui.PrintSelected("Intermediate private key", p.IntermediateKey) ui.PrintSelected("Intermediate private key", p.IntermediateKey)
} else if p.Defaults.Fingerprint != "" { case p.Defaults.Fingerprint != "":
ui.PrintSelected("Root certificate", p.Root[0]) ui.PrintSelected("Root certificate", p.Root[0])
ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint) ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint)
} else { default:
ui.Printf(`{{ "%s" | red }} {{ "Root certificate:" | bold }} failed to retrieve it from RA`+"\n", ui.IconBad) ui.Printf(`{{ "%s" | red }} {{ "Root certificate:" | bold }} failed to retrieve it from RA`+"\n", ui.IconBad)
} }
if p.options.enableSSH { if p.options.enableSSH {
@ -657,7 +658,7 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
authorityOptions = &p.casOptions authorityOptions = &p.casOptions
} }
config := &authconfig.Config{ cfg := &authconfig.Config{
Root: p.Root, Root: p.Root,
FederatedRoots: p.FederatedRoots, FederatedRoots: p.FederatedRoots,
IntermediateCert: p.Intermediate, IntermediateCert: p.Intermediate,
@ -681,7 +682,7 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
// Add linked as a deployment type to detect it on start and provide a // Add linked as a deployment type to detect it on start and provide a
// message if the token is not given. // message if the token is not given.
if p.options.deploymentType == LinkedDeployment { if p.options.deploymentType == LinkedDeployment {
config.AuthorityConfig.DeploymentType = LinkedDeployment.String() cfg.AuthorityConfig.DeploymentType = LinkedDeployment.String()
} }
// On standalone deployments add the provisioners to either the ca.json or // On standalone deployments add the provisioners to either the ca.json or
@ -711,7 +712,7 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
if p.options.enableSSH { if p.options.enableSSH {
enableSSHCA := true enableSSHCA := true
config.SSH = &authconfig.SSHConfig{ cfg.SSH = &authconfig.SSHConfig{
HostKey: p.Ssh.HostKey, HostKey: p.Ssh.HostKey,
UserKey: p.Ssh.UserKey, UserKey: p.Ssh.UserKey,
} }
@ -733,19 +734,19 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
// Apply configuration modifiers // Apply configuration modifiers
for _, o := range opt { for _, o := range opt {
if err := o(config); err != nil { if err := o(cfg); err != nil {
return nil, err return nil, err
} }
} }
// Set authority.enableAdmin to true // Set authority.enableAdmin to true
if p.options.enableAdmin { if p.options.enableAdmin {
config.AuthorityConfig.EnableAdmin = true cfg.AuthorityConfig.EnableAdmin = true
} }
if p.options.deploymentType == StandaloneDeployment { if p.options.deploymentType == StandaloneDeployment {
if !config.AuthorityConfig.EnableAdmin { if !cfg.AuthorityConfig.EnableAdmin {
config.AuthorityConfig.Provisioners = provisioners cfg.AuthorityConfig.Provisioners = provisioners
} else { } else {
// At this moment this code path is never used because `step ca // At this moment this code path is never used because `step ca
// init` will always set enableAdmin to false for a standalone // init` will always set enableAdmin to false for a standalone
@ -754,11 +755,11 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
// //
// Note that we might want to be able to define the database as a // Note that we might want to be able to define the database as a
// flag in `step ca init` so we can write to the proper place. // flag in `step ca init` so we can write to the proper place.
db, err := db.New(config.DB) _db, err := db.New(cfg.DB)
if err != nil { if err != nil {
return nil, err return nil, err
} }
adminDB, err := admindb.New(db.(nosql.DB), admin.DefaultAuthorityID) adminDB, err := admindb.New(_db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -788,7 +789,7 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
} }
} }
return config, nil return cfg, nil
} }
// Save stores the pki on a json file that will be used as the certificate // Save stores the pki on a json file that will be used as the certificate
@ -804,12 +805,12 @@ func (p *PKI) Save(opt ...ConfigOption) error {
// Generate and write ca.json // Generate and write ca.json
if !p.options.pkiOnly { if !p.options.pkiOnly {
config, err := p.GenerateConfig(opt...) cfg, err := p.GenerateConfig(opt...)
if err != nil { if err != nil {
return err return err
} }
b, err := json.MarshalIndent(config, "", "\t") b, err := json.MarshalIndent(cfg, "", "\t")
if err != nil { if err != nil {
return errors.Wrapf(err, "error marshaling %s", p.config) return errors.Wrapf(err, "error marshaling %s", p.config)
} }
@ -833,14 +834,14 @@ func (p *PKI) Save(opt ...ConfigOption) error {
} }
// Generate and write templates // Generate and write templates
if err := generateTemplates(config.Templates); err != nil { if err := generateTemplates(cfg.Templates); err != nil {
return err return err
} }
if config.DB != nil { if cfg.DB != nil {
ui.PrintSelected("Database folder", config.DB.DataSource) ui.PrintSelected("Database folder", cfg.DB.DataSource)
} }
if config.Templates != nil { if cfg.Templates != nil {
ui.PrintSelected("Templates folder", GetTemplatesPath()) ui.PrintSelected("Templates folder", GetTemplatesPath())
} }

View file

@ -198,14 +198,14 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
return return
} }
provisioner, ok := p.(*provisioner.SCEP) prov, ok := p.(*provisioner.SCEP)
if !ok { if !ok {
api.WriteError(w, errors.New("provisioner must be of type SCEP")) api.WriteError(w, errors.New("provisioner must be of type SCEP"))
return return
} }
ctx := r.Context() ctx := r.Context()
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(provisioner)) ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov))
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }

View file

@ -5,6 +5,7 @@ import (
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"text/template" "text/template"
"github.com/Masterminds/sprig/v3" "github.com/Masterminds/sprig/v3"
@ -226,14 +227,11 @@ func (t *Template) Output(data interface{}) (Output, error) {
// backfill updates old templates with the required data. // backfill updates old templates with the required data.
func (t *Template) backfill(b []byte) { func (t *Template) backfill(b []byte) {
switch t.Name { if strings.EqualFold(t.Name, "sshd_config.tpl") && len(t.RequiredData) == 0 {
case "sshd_config.tpl": a := bytes.TrimSpace(b)
if len(t.RequiredData) == 0 { b := bytes.TrimSpace([]byte(DefaultSSHTemplateData[t.Name]))
a := bytes.TrimSpace(b) if bytes.Equal(a, b) {
b := bytes.TrimSpace([]byte(DefaultSSHTemplateData[t.Name])) t.RequiredData = []string{"Certificate", "Key"}
if bytes.Equal(a, b) {
t.RequiredData = []string{"Certificate", "Key"}
}
} }
} }
} }