forked from TrueCloudLab/certificates
Merge pull request #720 from smallstep/max/lint
Introduce gocritic linter and address warnings
This commit is contained in:
commit
04fe3126be
92 changed files with 709 additions and 751 deletions
|
@ -36,22 +36,30 @@ linters-settings:
|
||||||
- performance
|
- performance
|
||||||
- style
|
- style
|
||||||
- experimental
|
- experimental
|
||||||
|
- diagnostic
|
||||||
disabled-checks:
|
disabled-checks:
|
||||||
- wrapperFunc
|
- commentFormatting
|
||||||
- dupImport # https://github.com/go-critic/go-critic/issues/845
|
- commentedOutCode
|
||||||
|
- evalOrder
|
||||||
|
- hugeParam
|
||||||
|
- octalLiteral
|
||||||
|
- rangeValCopy
|
||||||
|
- tooManyResultsChecker
|
||||||
|
- unnamedResult
|
||||||
|
|
||||||
linters:
|
linters:
|
||||||
disable-all: true
|
disable-all: true
|
||||||
enable:
|
enable:
|
||||||
- gofmt
|
|
||||||
- revive
|
|
||||||
- govet
|
|
||||||
- misspell
|
|
||||||
- ineffassign
|
|
||||||
- deadcode
|
- deadcode
|
||||||
|
- gocritic
|
||||||
|
- gofmt
|
||||||
|
- gosimple
|
||||||
|
- govet
|
||||||
|
- ineffassign
|
||||||
|
- misspell
|
||||||
|
- revive
|
||||||
- staticcheck
|
- staticcheck
|
||||||
- unused
|
- unused
|
||||||
- gosimple
|
|
||||||
|
|
||||||
run:
|
run:
|
||||||
skip-dirs:
|
skip-dirs:
|
||||||
|
|
|
@ -6,10 +6,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
||||||
|
|
||||||
## [Unreleased - 0.17.5] - DATE
|
## [Unreleased - 0.17.5] - DATE
|
||||||
### Added
|
### Added
|
||||||
|
- gocritic linter
|
||||||
### Changed
|
### Changed
|
||||||
### Deprecated
|
### Deprecated
|
||||||
### Removed
|
### Removed
|
||||||
### Fixed
|
### Fixed
|
||||||
|
- gocritic warnings
|
||||||
### Security
|
### Security
|
||||||
|
|
||||||
## [0.17.4] - 2021-09-28
|
## [0.17.4] - 2021-09-28
|
||||||
|
|
2
Makefile
2
Makefile
|
@ -154,7 +154,7 @@ fmt:
|
||||||
$Q gofmt -l -w $(SRC)
|
$Q gofmt -l -w $(SRC)
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
$Q $(GOFLAGS) LOG_LEVEL=error golangci-lint run --timeout=30m
|
$Q golangci-lint run --timeout=30m
|
||||||
|
|
||||||
lintcgo:
|
lintcgo:
|
||||||
$Q LOG_LEVEL=error golangci-lint run --timeout=30m
|
$Q LOG_LEVEL=error golangci-lint run --timeout=30m
|
||||||
|
|
|
@ -19,7 +19,7 @@ type NewAccountRequest struct {
|
||||||
|
|
||||||
func validateContacts(cs []string) error {
|
func validateContacts(cs []string) error {
|
||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
if len(c) == 0 {
|
if c == "" {
|
||||||
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
|
return acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -178,7 +178,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
|
||||||
url := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
|
u := fmt.Sprintf("http://ca.smallstep.com/acme/%s/account/%s/orders", provName, accID)
|
||||||
|
|
||||||
oids := []string{"foo", "bar"}
|
oids := []string{"foo", "bar"}
|
||||||
oidURLs := []string{
|
oidURLs := []string{
|
||||||
|
@ -255,7 +255,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetOrdersByAccountID(w, req)
|
h.GetOrdersByAccountID(w, req)
|
||||||
|
|
|
@ -148,7 +148,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("authzID", az.ID)
|
chiCtx.URLParams.Add("authzID", az.ID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/authz/%s",
|
u := fmt.Sprintf("%s/acme/%s/authz/%s",
|
||||||
baseURL.String(), provName, az.ID)
|
baseURL.String(), provName, az.ID)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -280,7 +280,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
|
||||||
expB, err := json.Marshal(az)
|
expB, err := json.Marshal(az)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -314,7 +314,7 @@ func TestHandler_GetCertificate(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("certID", certID)
|
chiCtx.URLParams.Add("certID", certID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/certificate/%s",
|
u := fmt.Sprintf("%s/acme/%s/certificate/%s",
|
||||||
baseURL.String(), provName, certID)
|
baseURL.String(), provName, certID)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -396,7 +396,7 @@ func TestHandler_GetCertificate(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetCertificate(w, req)
|
h.GetCertificate(w, req)
|
||||||
|
@ -434,7 +434,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
|
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
|
|
||||||
url := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
|
u := fmt.Sprintf("%s/acme/%s/challenge/%s/%s",
|
||||||
baseURL.String(), provName, "authzID", "chID")
|
baseURL.String(), provName, "authzID", "chID")
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -635,7 +635,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
AuthorizationID: "authzID",
|
AuthorizationID: "authzID",
|
||||||
Type: acme.HTTP01,
|
Type: acme.HTTP01,
|
||||||
AccountID: "accID",
|
AccountID: "accID",
|
||||||
URL: url,
|
URL: u,
|
||||||
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
Error: acme.NewError(acme.ErrorConnectionType, "force"),
|
||||||
},
|
},
|
||||||
vco: &acme.ValidateChallengeOptions{
|
vco: &acme.ValidateChallengeOptions{
|
||||||
|
@ -652,7 +652,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetChallenge(w, req)
|
h.GetChallenge(w, req)
|
||||||
|
@ -678,7 +678,7 @@ func TestHandler_GetChallenge(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
|
assert.Equals(t, res.Header["Link"], []string{fmt.Sprintf("<%s/acme/%s/authz/%s>;rel=\"up\"", baseURL, provName, "authzID")})
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -223,7 +223,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "jwk and kid are mutually exclusive"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if hdr.JSONWebKey == nil && len(hdr.KeyID) == 0 {
|
if hdr.JSONWebKey == nil && hdr.KeyID == "" {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "either jwk or kid must be defined in jws protected header"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -367,7 +367,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
|
||||||
api.WriteError(w, err)
|
api.WriteError(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(jwk.Algorithm) != 0 && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
if jwk.Algorithm != "" && jwk.Algorithm != jws.Signatures[0].Protected.Algorithm {
|
||||||
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
api.WriteError(w, acme.NewError(acme.ErrorMalformedType, "verifier and signature algorithm do not match"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -108,7 +108,7 @@ func TestHandler_baseURLFromRequest(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_addNonce(t *testing.T) {
|
func TestHandler_addNonce(t *testing.T) {
|
||||||
url := "https://ca.smallstep.com/acme/new-nonce"
|
u := "https://ca.smallstep.com/acme/new-nonce"
|
||||||
type test struct {
|
type test struct {
|
||||||
db acme.DB
|
db acme.DB
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
|
@ -141,7 +141,7 @@ func TestHandler_addNonce(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.addNonce(testNext)(w, req)
|
h.addNonce(testNext)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
@ -230,7 +230,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
escProvName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
url := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
|
||||||
type test struct {
|
type test struct {
|
||||||
h Handler
|
h Handler
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -245,7 +245,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
h: Handler{
|
h: Handler{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: NewLinker("dns", "acme"),
|
||||||
},
|
},
|
||||||
url: url,
|
url: u,
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
contentType: "foo",
|
contentType: "foo",
|
||||||
statusCode: 500,
|
statusCode: 500,
|
||||||
|
@ -257,7 +257,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
h: Handler{
|
h: Handler{
|
||||||
linker: NewLinker("dns", "acme"),
|
linker: NewLinker("dns", "acme"),
|
||||||
},
|
},
|
||||||
url: url,
|
url: u,
|
||||||
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
|
||||||
contentType: "foo",
|
contentType: "foo",
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
|
@ -319,11 +319,11 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
_url := url
|
_u := u
|
||||||
if tc.url != "" {
|
if tc.url != "" {
|
||||||
_url = tc.url
|
_u = tc.url
|
||||||
}
|
}
|
||||||
req := httptest.NewRequest("GET", _url, nil)
|
req := httptest.NewRequest("GET", _u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
req.Header.Add("Content-Type", tc.contentType)
|
req.Header.Add("Content-Type", tc.contentType)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -353,7 +353,7 @@ func TestHandler_verifyContentType(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_isPostAsGet(t *testing.T) {
|
func TestHandler_isPostAsGet(t *testing.T) {
|
||||||
url := "https://ca.smallstep.com/acme/new-account"
|
u := "https://ca.smallstep.com/acme/new-account"
|
||||||
type test struct {
|
type test struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
err *acme.Error
|
err *acme.Error
|
||||||
|
@ -392,7 +392,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.isPostAsGet(testNext)(w, req)
|
h.isPostAsGet(testNext)(w, req)
|
||||||
|
@ -430,7 +430,7 @@ func (errReader) Close() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_parseJWS(t *testing.T) {
|
func TestHandler_parseJWS(t *testing.T) {
|
||||||
url := "https://ca.smallstep.com/acme/new-account"
|
u := "https://ca.smallstep.com/acme/new-account"
|
||||||
type test struct {
|
type test struct {
|
||||||
next nextHTTP
|
next nextHTTP
|
||||||
body io.Reader
|
body io.Reader
|
||||||
|
@ -483,7 +483,7 @@ func TestHandler_parseJWS(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", url, tc.body)
|
req := httptest.NewRequest("GET", u, tc.body)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.parseJWS(tc.next)(w, req)
|
h.parseJWS(tc.next)(w, req)
|
||||||
res := w.Result()
|
res := w.Result()
|
||||||
|
@ -528,7 +528,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
parsedJWS, err := jose.ParseJWS(raw)
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
url := "https://ca.smallstep.com/acme/account/1234"
|
u := "https://ca.smallstep.com/acme/account/1234"
|
||||||
type test struct {
|
type test struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
next func(http.ResponseWriter, *http.Request)
|
next func(http.ResponseWriter, *http.Request)
|
||||||
|
@ -681,7 +681,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{}
|
h := &Handler{}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
h.verifyAndExtractJWSPayload(tc.next)(w, req)
|
||||||
|
@ -713,7 +713,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
provName := url.PathEscape(prov.GetName())
|
provName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
url := fmt.Sprintf("%s/acme/%s/account/1234",
|
u := fmt.Sprintf("%s/acme/%s/account/1234",
|
||||||
baseURL, provName)
|
baseURL, provName)
|
||||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
@ -883,7 +883,7 @@ func TestHandler_lookupJWK(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db, linker: tc.linker}
|
h := &Handler{db: tc.db, linker: tc.linker}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.lookupJWK(tc.next)(w, req)
|
h.lookupJWK(tc.next)(w, req)
|
||||||
|
@ -934,7 +934,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
parsedJWS, err := jose.ParseJWS(raw)
|
parsedJWS, err := jose.ParseJWS(raw)
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
url := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
|
u := fmt.Sprintf("https://ca.smallstep.com/acme/%s/account/1234",
|
||||||
provName)
|
provName)
|
||||||
type test struct {
|
type test struct {
|
||||||
db acme.DB
|
db acme.DB
|
||||||
|
@ -1079,7 +1079,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.extractJWK(tc.next)(w, req)
|
h.extractJWK(tc.next)(w, req)
|
||||||
|
@ -1108,7 +1108,7 @@ func TestHandler_extractJWK(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandler_validateJWS(t *testing.T) {
|
func TestHandler_validateJWS(t *testing.T) {
|
||||||
url := "https://ca.smallstep.com/acme/account/1234"
|
u := "https://ca.smallstep.com/acme/account/1234"
|
||||||
type test struct {
|
type test struct {
|
||||||
db acme.DB
|
db acme.DB
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
@ -1198,7 +1198,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.RS256,
|
Algorithm: jose.RS256,
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1226,7 +1226,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.RS256,
|
Algorithm: jose.RS256,
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1298,7 +1298,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
},
|
},
|
||||||
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
ctx: context.WithValue(context.Background(), jwsContextKey, jws),
|
||||||
statusCode: 400,
|
statusCode: 400,
|
||||||
err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", url),
|
err: acme.NewError(acme.ErrorMalformedType, "url header in JWS (foo) does not match request url (%s)", u),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"fail/both-jwk-kid": func(t *testing.T) test {
|
"fail/both-jwk-kid": func(t *testing.T) test {
|
||||||
|
@ -1313,7 +1313,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
KeyID: "bar",
|
KeyID: "bar",
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1337,7 +1337,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Protected: jose.Header{
|
Protected: jose.Header{
|
||||||
Algorithm: jose.ES256,
|
Algorithm: jose.ES256,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1362,7 +1362,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.ES256,
|
Algorithm: jose.ES256,
|
||||||
KeyID: "bar",
|
KeyID: "bar",
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1392,7 +1392,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.ES256,
|
Algorithm: jose.ES256,
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1422,7 +1422,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
Algorithm: jose.RS256,
|
Algorithm: jose.RS256,
|
||||||
JSONWebKey: &pub,
|
JSONWebKey: &pub,
|
||||||
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||||
"url": url,
|
"url": u,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
@ -1446,7 +1446,7 @@ func TestHandler_validateJWS(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{db: tc.db}
|
h := &Handler{db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.validateJWS(tc.next)(w, req)
|
h.validateJWS(tc.next)(w, req)
|
||||||
|
|
|
@ -264,7 +264,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("ordID", o.ID)
|
chiCtx.URLParams.Add("ordID", o.ID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/order/%s",
|
u := fmt.Sprintf("%s/acme/%s/order/%s",
|
||||||
baseURL.String(), escProvName, o.ID)
|
baseURL.String(), escProvName, o.ID)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -422,7 +422,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.GetOrder(w, req)
|
h.GetOrder(w, req)
|
||||||
|
@ -448,7 +448,7 @@ func TestHandler_GetOrder(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
|
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -663,7 +663,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
prov := newProv()
|
prov := newProv()
|
||||||
escProvName := url.PathEscape(prov.GetName())
|
escProvName := url.PathEscape(prov.GetName())
|
||||||
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
|
||||||
url := fmt.Sprintf("%s/acme/%s/order/ordID",
|
u := fmt.Sprintf("%s/acme/%s/order/ordID",
|
||||||
baseURL.String(), escProvName)
|
baseURL.String(), escProvName)
|
||||||
|
|
||||||
type test struct {
|
type test struct {
|
||||||
|
@ -1335,7 +1335,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.NewOrder(w, req)
|
h.NewOrder(w, req)
|
||||||
|
@ -1363,7 +1363,7 @@ func TestHandler_NewOrder(t *testing.T) {
|
||||||
tc.vr(t, ro)
|
tc.vr(t, ro)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -1406,7 +1406,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
// Request with chi context
|
// Request with chi context
|
||||||
chiCtx := chi.NewRouteContext()
|
chiCtx := chi.NewRouteContext()
|
||||||
chiCtx.URLParams.Add("ordID", o.ID)
|
chiCtx.URLParams.Add("ordID", o.ID)
|
||||||
url := fmt.Sprintf("%s/acme/%s/order/%s",
|
u := fmt.Sprintf("%s/acme/%s/order/%s",
|
||||||
baseURL.String(), escProvName, o.ID)
|
baseURL.String(), escProvName, o.ID)
|
||||||
|
|
||||||
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
_csr, err := pemutil.Read("../../authority/testdata/certs/foo.csr")
|
||||||
|
@ -1625,7 +1625,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db}
|
||||||
req := httptest.NewRequest("GET", url, nil)
|
req := httptest.NewRequest("GET", u, nil)
|
||||||
req = req.WithContext(tc.ctx)
|
req = req.WithContext(tc.ctx)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
h.FinalizeOrder(w, req)
|
h.FinalizeOrder(w, req)
|
||||||
|
@ -1654,7 +1654,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
|
||||||
assert.FatalError(t, json.Unmarshal(body, ro))
|
assert.FatalError(t, json.Unmarshal(body, ro))
|
||||||
|
|
||||||
assert.Equals(t, bytes.TrimSpace(body), expB)
|
assert.Equals(t, bytes.TrimSpace(body), expB)
|
||||||
assert.Equals(t, res.Header["Location"], []string{url})
|
assert.Equals(t, res.Header["Location"], []string{u})
|
||||||
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
assert.Equals(t, res.Header["Content-Type"], []string{"application/json"})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -76,23 +76,23 @@ func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error {
|
||||||
url := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
|
||||||
|
|
||||||
resp, err := vo.HTTPGet(url.String())
|
resp, err := vo.HTTPGet(u.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
|
||||||
"error doing http GET for url %s", url))
|
"error doing http GET for url %s", u))
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
|
return storeError(ctx, db, ch, false, NewError(ErrorConnectionType,
|
||||||
"error doing http GET for url %s with status code %d", url, resp.StatusCode))
|
"error doing http GET for url %s with status code %d", u, resp.StatusCode))
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WrapErrorISE(err, "error reading "+
|
return WrapErrorISE(err, "error reading "+
|
||||||
"response body for url %s", url)
|
"response body for url %s", u)
|
||||||
}
|
}
|
||||||
keyAuth := strings.TrimSpace(string(body))
|
keyAuth := strings.TrimSpace(string(body))
|
||||||
|
|
||||||
|
|
|
@ -1276,7 +1276,7 @@ func newTLSALPNValidationCert(keyAuthHash []byte, obsoleteOID, critical bool, na
|
||||||
oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
|
oid = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 30, 1}
|
||||||
}
|
}
|
||||||
|
|
||||||
keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash[:])
|
keyAuthHashEnc, _ := asn1.Marshal(keyAuthHash)
|
||||||
|
|
||||||
certTemplate.ExtraExtensions = []pkix.Extension{
|
certTemplate.ExtraExtensions = []pkix.Extension{
|
||||||
{
|
{
|
||||||
|
|
|
@ -93,8 +93,8 @@ func TestDB_getDBAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if dbacc, err := db.getDBAccount(context.Background(), accID); err != nil {
|
if dbacc, err := d.getDBAccount(context.Background(), accID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -109,8 +109,7 @@ func TestDB_getDBAccount(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
assert.Equals(t, dbacc.ID, tc.dbacc.ID)
|
||||||
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
assert.Equals(t, dbacc.Status, tc.dbacc.Status)
|
||||||
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
assert.Equals(t, dbacc.CreatedAt, tc.dbacc.CreatedAt)
|
||||||
|
@ -118,7 +117,6 @@ func TestDB_getDBAccount(t *testing.T) {
|
||||||
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
assert.Equals(t, dbacc.Contact, tc.dbacc.Contact)
|
||||||
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
assert.Equals(t, dbacc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -174,8 +172,8 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if retAccID, err := db.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
if retAccID, err := d.getAccountIDByKeyID(context.Background(), kid); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -190,11 +188,9 @@ func TestDB_getAccountIDByKeyID(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, retAccID, accID)
|
assert.Equals(t, retAccID, accID)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -250,8 +246,8 @@ func TestDB_GetAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if acc, err := db.GetAccount(context.Background(), accID); err != nil {
|
if acc, err := d.GetAccount(context.Background(), accID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -266,14 +262,12 @@ func TestDB_GetAccount(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -358,8 +352,8 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if acc, err := db.GetAccountByKeyID(context.Background(), kid); err != nil {
|
if acc, err := d.GetAccountByKeyID(context.Background(), kid); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -374,14 +368,12 @@ func TestDB_GetAccountByKeyID(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
assert.Equals(t, acc.ID, tc.dbacc.ID)
|
||||||
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
assert.Equals(t, acc.Status, tc.dbacc.Status)
|
||||||
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
assert.Equals(t, acc.Contact, tc.dbacc.Contact)
|
||||||
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
assert.Equals(t, acc.Key.KeyID, tc.dbacc.Key.KeyID)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -527,8 +519,8 @@ func TestDB_CreateAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateAccount(context.Background(), tc.acc); err != nil {
|
if err := d.CreateAccount(context.Background(), tc.acc); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -688,8 +680,8 @@ func TestDB_UpdateAccount(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.UpdateAccount(context.Background(), tc.acc); err != nil {
|
if err := d.UpdateAccount(context.Background(), tc.acc); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,8 +97,8 @@ func TestDB_getDBAuthz(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if dbaz, err := db.getDBAuthz(context.Background(), azID); err != nil {
|
if dbaz, err := d.getDBAuthz(context.Background(), azID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -113,8 +113,7 @@ func TestDB_getDBAuthz(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
assert.Equals(t, dbaz.ID, tc.dbaz.ID)
|
||||||
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
assert.Equals(t, dbaz.AccountID, tc.dbaz.AccountID)
|
||||||
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
assert.Equals(t, dbaz.Identifier, tc.dbaz.Identifier)
|
||||||
|
@ -125,7 +124,6 @@ func TestDB_getDBAuthz(t *testing.T) {
|
||||||
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
assert.Equals(t, dbaz.Error.Error(), tc.dbaz.Error.Error())
|
||||||
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
assert.Equals(t, dbaz.Wildcard, tc.dbaz.Wildcard)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -293,8 +291,8 @@ func TestDB_GetAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if az, err := db.GetAuthorization(context.Background(), azID); err != nil {
|
if az, err := d.GetAuthorization(context.Background(), azID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -309,8 +307,7 @@ func TestDB_GetAuthorization(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, az.ID, tc.dbaz.ID)
|
assert.Equals(t, az.ID, tc.dbaz.ID)
|
||||||
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
assert.Equals(t, az.AccountID, tc.dbaz.AccountID)
|
||||||
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
assert.Equals(t, az.Identifier, tc.dbaz.Identifier)
|
||||||
|
@ -324,7 +321,6 @@ func TestDB_GetAuthorization(t *testing.T) {
|
||||||
})
|
})
|
||||||
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
assert.Equals(t, az.Error.Error(), tc.dbaz.Error.Error())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -445,8 +441,8 @@ func TestDB_CreateAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateAuthorization(context.Background(), tc.az); err != nil {
|
if err := d.CreateAuthorization(context.Background(), tc.az); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -594,8 +590,8 @@ func TestDB_UpdateAuthorization(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
if err := d.UpdateAuthorization(context.Background(), tc.az); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,8 +98,8 @@ func TestDB_CreateCertificate(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateCertificate(context.Background(), tc.cert); err != nil {
|
if err := d.CreateCertificate(context.Background(), tc.cert); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -228,8 +228,8 @@ func TestDB_GetCertificate(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
cert, err := db.GetCertificate(context.Background(), certID)
|
cert, err := d.GetCertificate(context.Background(), certID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
|
@ -245,15 +245,13 @@ func TestDB_GetCertificate(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, cert.ID, certID)
|
assert.Equals(t, cert.ID, certID)
|
||||||
assert.Equals(t, cert.AccountID, "accountID")
|
assert.Equals(t, cert.AccountID, "accountID")
|
||||||
assert.Equals(t, cert.OrderID, "orderID")
|
assert.Equals(t, cert.OrderID, "orderID")
|
||||||
assert.Equals(t, cert.Leaf, leaf)
|
assert.Equals(t, cert.Leaf, leaf)
|
||||||
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
assert.Equals(t, cert.Intermediates, []*x509.Certificate{inter, root})
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -92,8 +92,8 @@ func TestDB_getDBChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if ch, err := db.getDBChallenge(context.Background(), chID); err != nil {
|
if ch, err := d.getDBChallenge(context.Background(), chID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -108,8 +108,7 @@ func TestDB_getDBChallenge(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||||
|
@ -119,7 +118,6 @@ func TestDB_getDBChallenge(t *testing.T) {
|
||||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -206,8 +204,8 @@ func TestDB_CreateChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateChallenge(context.Background(), tc.ch); err != nil {
|
if err := d.CreateChallenge(context.Background(), tc.ch); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -286,8 +284,8 @@ func TestDB_GetChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if ch, err := db.GetChallenge(context.Background(), chID, azID); err != nil {
|
if ch, err := d.GetChallenge(context.Background(), chID, azID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -302,8 +300,7 @@ func TestDB_GetChallenge(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, ch.ID, tc.dbc.ID)
|
assert.Equals(t, ch.ID, tc.dbc.ID)
|
||||||
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
assert.Equals(t, ch.AccountID, tc.dbc.AccountID)
|
||||||
assert.Equals(t, ch.Type, tc.dbc.Type)
|
assert.Equals(t, ch.Type, tc.dbc.Type)
|
||||||
|
@ -313,7 +310,6 @@ func TestDB_GetChallenge(t *testing.T) {
|
||||||
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
assert.Equals(t, ch.ValidatedAt, tc.dbc.ValidatedAt)
|
||||||
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
assert.Equals(t, ch.Error.Error(), tc.dbc.Error.Error())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -442,8 +438,8 @@ func TestDB_UpdateChallenge(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
if err := d.UpdateChallenge(context.Background(), tc.ch); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ func (db *DB) CreateNonce(ctx context.Context) (acme.Nonce, error) {
|
||||||
ID: id,
|
ID: id,
|
||||||
CreatedAt: clock.Now(),
|
CreatedAt: clock.Now(),
|
||||||
}
|
}
|
||||||
if err = db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
if err := db.save(ctx, id, n, nil, "nonce", nonceTable); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
return acme.Nonce(id), nil
|
return acme.Nonce(id), nil
|
||||||
|
|
|
@ -67,8 +67,8 @@ func TestDB_CreateNonce(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if n, err := db.CreateNonce(context.Background()); err != nil {
|
if n, err := d.CreateNonce(context.Background()); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -144,8 +144,8 @@ func TestDB_DeleteNonce(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
if err := d.DeleteNonce(context.Background(), acme.Nonce(nonceID)); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
|
|
@ -41,7 +41,7 @@ func New(db nosqlDB.DB) (*DB, error) {
|
||||||
|
|
||||||
// save writes the new data to the database, overwriting the old data if it
|
// save writes the new data to the database, overwriting the old data if it
|
||||||
// existed.
|
// existed.
|
||||||
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
|
func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
newB []byte
|
newB []byte
|
||||||
|
|
|
@ -126,8 +126,8 @@ func TestDB_save(t *testing.T) {
|
||||||
}
|
}
|
||||||
for name, tc := range tests {
|
for name, tc := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := &DB{db: tc.db}
|
d := &DB{db: tc.db}
|
||||||
if err := db.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
if err := d.save(context.Background(), "id", tc.nu, tc.old, "challenge", challengeTable); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
"github.com/smallstep/nosql/database"
|
"github.com/smallstep/nosql/database"
|
||||||
nosqldb "github.com/smallstep/nosql/database"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDB_getDBOrder(t *testing.T) {
|
func TestDB_getDBOrder(t *testing.T) {
|
||||||
|
@ -32,7 +31,7 @@ func TestDB_getDBOrder(t *testing.T) {
|
||||||
assert.Equals(t, bucket, orderTable)
|
assert.Equals(t, bucket, orderTable)
|
||||||
assert.Equals(t, string(key), orderID)
|
assert.Equals(t, string(key), orderID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
||||||
|
@ -101,8 +100,8 @@ func TestDB_getDBOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if dbo, err := db.getDBOrder(context.Background(), orderID); err != nil {
|
if dbo, err := d.getDBOrder(context.Background(), orderID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -117,8 +116,7 @@ func TestDB_getDBOrder(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, dbo.ID, tc.dbo.ID)
|
assert.Equals(t, dbo.ID, tc.dbo.ID)
|
||||||
assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
|
assert.Equals(t, dbo.ProvisionerID, tc.dbo.ProvisionerID)
|
||||||
assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
|
assert.Equals(t, dbo.CertificateID, tc.dbo.CertificateID)
|
||||||
|
@ -131,7 +129,6 @@ func TestDB_getDBOrder(t *testing.T) {
|
||||||
assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
assert.Equals(t, dbo.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
||||||
assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
|
assert.Equals(t, dbo.Error.Error(), tc.dbo.Error.Error())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -165,7 +162,7 @@ func TestDB_GetOrder(t *testing.T) {
|
||||||
assert.Equals(t, bucket, orderTable)
|
assert.Equals(t, bucket, orderTable)
|
||||||
assert.Equals(t, string(key), orderID)
|
assert.Equals(t, string(key), orderID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
acmeErr: acme.NewError(acme.ErrorMalformedType, "order orderID not found"),
|
||||||
|
@ -207,8 +204,8 @@ func TestDB_GetOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if o, err := db.GetOrder(context.Background(), orderID); err != nil {
|
if o, err := d.GetOrder(context.Background(), orderID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *acme.Error:
|
case *acme.Error:
|
||||||
if assert.NotNil(t, tc.acmeErr) {
|
if assert.NotNil(t, tc.acmeErr) {
|
||||||
|
@ -223,8 +220,7 @@ func TestDB_GetOrder(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, o.ID, tc.dbo.ID)
|
assert.Equals(t, o.ID, tc.dbo.ID)
|
||||||
assert.Equals(t, o.AccountID, tc.dbo.AccountID)
|
assert.Equals(t, o.AccountID, tc.dbo.AccountID)
|
||||||
assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
|
assert.Equals(t, o.ProvisionerID, tc.dbo.ProvisionerID)
|
||||||
|
@ -237,7 +233,6 @@ func TestDB_GetOrder(t *testing.T) {
|
||||||
assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
assert.Equals(t, o.AuthorizationIDs, tc.dbo.AuthorizationIDs)
|
||||||
assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
|
assert.Equals(t, o.Error.Error(), tc.dbo.Error.Error())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -367,8 +362,8 @@ func TestDB_UpdateOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.UpdateOrder(context.Background(), tc.o); err != nil {
|
if err := d.UpdateOrder(context.Background(), tc.o); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -512,7 +507,7 @@ func TestDB_CreateOrder(t *testing.T) {
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
|
assert.Equals(t, string(bucket), string(ordersByAccountIDTable))
|
||||||
assert.Equals(t, string(key), o.AccountID)
|
assert.Equals(t, string(key), o.AccountID)
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
switch string(bucket) {
|
switch string(bucket) {
|
||||||
|
@ -558,8 +553,8 @@ func TestDB_CreateOrder(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if err := db.CreateOrder(context.Background(), tc.o); err != nil {
|
if err := d.CreateOrder(context.Background(), tc.o); err != nil {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
|
@ -681,7 +676,7 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
||||||
MGet: func(bucket, key []byte) ([]byte, error) {
|
MGet: func(bucket, key []byte) ([]byte, error) {
|
||||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||||
assert.Equals(t, key, []byte(accID))
|
assert.Equals(t, key, []byte(accID))
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
MCmpAndSwap: func(bucket, key, old, nu []byte) ([]byte, bool, error) {
|
||||||
assert.Equals(t, bucket, ordersByAccountIDTable)
|
assert.Equals(t, bucket, ordersByAccountIDTable)
|
||||||
|
@ -996,15 +991,15 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
var (
|
var (
|
||||||
res []string
|
res []string
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if tc.addOids == nil {
|
if tc.addOids == nil {
|
||||||
res, err = db.updateAddOrderIDs(context.Background(), accID)
|
res, err = d.updateAddOrderIDs(context.Background(), accID)
|
||||||
} else {
|
} else {
|
||||||
res, err = db.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
|
res, err = d.updateAddOrderIDs(context.Background(), accID, tc.addOids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1022,11 +1017,9 @@ func TestDB_updateAddOrderIDs(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.True(t, reflect.DeepEqual(res, tc.res))
|
assert.True(t, reflect.DeepEqual(res, tc.res))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -289,6 +289,7 @@ func canonicalize(csr *x509.CertificateRequest) (canonicalized *x509.Certificate
|
||||||
// name or in an extensionRequest attribute [RFC2985] requesting a
|
// name or in an extensionRequest attribute [RFC2985] requesting a
|
||||||
// subjectAltName extension, or both.
|
// subjectAltName extension, or both.
|
||||||
if csr.Subject.CommonName != "" {
|
if csr.Subject.CommonName != "" {
|
||||||
|
// nolint:gocritic
|
||||||
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
canonicalized.DNSNames = append(csr.DNSNames, csr.Subject.CommonName)
|
||||||
}
|
}
|
||||||
canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
|
canonicalized.DNSNames = uniqueSortedLowerNames(csr.DNSNames)
|
||||||
|
|
11
api/api.go
11
api/api.go
|
@ -240,9 +240,9 @@ type caHandler struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new RouterHandler with the CA endpoints.
|
// New creates a new RouterHandler with the CA endpoints.
|
||||||
func New(authority Authority) RouterHandler {
|
func New(auth Authority) RouterHandler {
|
||||||
return &caHandler{
|
return &caHandler{
|
||||||
Authority: authority,
|
Authority: auth,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,7 +295,7 @@ func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) {
|
||||||
// certificate for the given SHA256.
|
// certificate for the given SHA256.
|
||||||
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) {
|
||||||
sha := chi.URLParam(r, "sha")
|
sha := chi.URLParam(r, "sha")
|
||||||
sum := strings.ToLower(strings.Replace(sha, "-", "", -1))
|
sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
|
||||||
// Load root certificate with the
|
// Load root certificate with the
|
||||||
cert, err := h.Authority.Root(sum)
|
cert, err := h.Authority.Root(sum)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -409,7 +409,9 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
||||||
"certificate": base64.StdEncoding.EncodeToString(cert.Raw),
|
"certificate": base64.StdEncoding.EncodeToString(cert.Raw),
|
||||||
}
|
}
|
||||||
for _, ext := range cert.Extensions {
|
for _, ext := range cert.Extensions {
|
||||||
if ext.Id.Equal(oidStepProvisioner) {
|
if !ext.Id.Equal(oidStepProvisioner) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
val := &stepProvisioner{}
|
val := &stepProvisioner{}
|
||||||
rest, err := asn1.Unmarshal(ext.Value, val)
|
rest, err := asn1.Unmarshal(ext.Value, val)
|
||||||
if err != nil || len(rest) > 0 {
|
if err != nil || len(rest) > 0 {
|
||||||
|
@ -422,7 +424,6 @@ func LogCertificate(w http.ResponseWriter, cert *x509.Certificate) {
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
|
||||||
rl.WithFields(m)
|
rl.WithFields(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -186,8 +186,8 @@ func TestCertificate_MarshalJSON(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
{"nil", fields{Certificate: nil}, []byte("null"), false},
|
||||||
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
|
{"empty", fields{Certificate: &x509.Certificate{Raw: nil}}, []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n"`), false},
|
||||||
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"`), false},
|
{"root", fields{Certificate: parseCertificate(rootPEM)}, []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"`), false},
|
||||||
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`), false},
|
{"cert", fields{Certificate: parseCertificate(certPEM)}, []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`), false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -219,11 +219,11 @@ func TestCertificate_UnmarshalJSON(t *testing.T) {
|
||||||
{"invalid string", []byte(`"foobar"`), false, true},
|
{"invalid string", []byte(`"foobar"`), false, true},
|
||||||
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
||||||
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
|
{"empty csr", []byte(`"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"`), false, true},
|
||||||
{"invalid type", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), false, true},
|
{"invalid type", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), false, true},
|
||||||
{"empty string", []byte(`""`), false, false},
|
{"empty string", []byte(`""`), false, false},
|
||||||
{"json null", []byte(`null`), false, false},
|
{"json null", []byte(`null`), false, false},
|
||||||
{"valid root", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), true, false},
|
{"valid root", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), true, false},
|
||||||
{"valid cert", []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"`), true, false},
|
{"valid cert", []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"`), true, false},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -251,7 +251,7 @@ func TestCertificate_UnmarshalJSON_json(t *testing.T) {
|
||||||
{"empty crt (null)", `{"crt":null}`, false, false},
|
{"empty crt (null)", `{"crt":null}`, false, false},
|
||||||
{"empty crt (string)", `{"crt":""}`, false, false},
|
{"empty crt (string)", `{"crt":""}`, false, false},
|
||||||
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
|
{"empty crt", `{"crt":"-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE----\n"}`, false, true},
|
||||||
{"valid crt", `{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `"}`, true, false},
|
{"valid crt", `{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `"}`, true, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
type request struct {
|
type request struct {
|
||||||
|
@ -297,7 +297,7 @@ func TestCertificateRequest_MarshalJSON(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
|
{"nil", fields{CertificateRequest: nil}, []byte("null"), false},
|
||||||
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
|
{"empty", fields{CertificateRequest: &x509.CertificateRequest{}}, []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST-----\n"`), false},
|
||||||
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `\n"`), false},
|
{"csr", fields{CertificateRequest: parseCertificateRequest(csrPEM)}, []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `\n"`), false},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -329,10 +329,10 @@ func TestCertificateRequest_UnmarshalJSON(t *testing.T) {
|
||||||
{"invalid string", []byte(`"foobar"`), false, true},
|
{"invalid string", []byte(`"foobar"`), false, true},
|
||||||
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
{"invalid bytes 0", []byte{}, false, true}, {"invalid bytes 1", []byte{1}, false, true},
|
||||||
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
|
{"empty csr", []byte(`"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"`), false, true},
|
||||||
{"invalid type", []byte(`"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `"`), false, true},
|
{"invalid type", []byte(`"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `"`), false, true},
|
||||||
{"empty string", []byte(`""`), false, false},
|
{"empty string", []byte(`""`), false, false},
|
||||||
{"json null", []byte(`null`), false, false},
|
{"json null", []byte(`null`), false, false},
|
||||||
{"valid csr", []byte(`"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"`), true, false},
|
{"valid csr", []byte(`"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"`), true, false},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -360,7 +360,7 @@ func TestCertificateRequest_UnmarshalJSON_json(t *testing.T) {
|
||||||
{"empty csr (null)", `{"csr":null}`, false, false},
|
{"empty csr (null)", `{"csr":null}`, false, false},
|
||||||
{"empty csr (string)", `{"csr":""}`, false, false},
|
{"empty csr (string)", `{"csr":""}`, false, false},
|
||||||
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
|
{"empty csr", `{"csr":"-----BEGIN CERTIFICATE REQUEST-----\n-----END CERTIFICATE REQUEST----\n"}`, false, true},
|
||||||
{"valid csr", `{"csr":"` + strings.Replace(csrPEM, "\n", `\n`, -1) + `"}`, true, false},
|
{"valid csr", `{"csr":"` + strings.ReplaceAll(csrPEM, "\n", `\n`) + `"}`, true, false},
|
||||||
}
|
}
|
||||||
|
|
||||||
type request struct {
|
type request struct {
|
||||||
|
@ -739,7 +739,7 @@ func (m *mockAuthority) CheckSSHHost(ctx context.Context, principal, token strin
|
||||||
return m.ret1.(bool), m.err
|
return m.ret1.(bool), m.err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user string, hostname string) (*authority.Bastion, error) {
|
func (m *mockAuthority) GetSSHBastion(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
|
||||||
if m.getSSHBastion != nil {
|
if m.getSSHBastion != nil {
|
||||||
return m.getSSHBastion(ctx, user, hostname)
|
return m.getSSHBastion(ctx, user, hostname)
|
||||||
}
|
}
|
||||||
|
@ -816,7 +816,7 @@ func Test_caHandler_Root(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
|
req := httptest.NewRequest("GET", "http://example.com/root/efc7d6b475a56fe587650bcdb999a4a308f815ba44db4bf0371ea68a786ccd36", nil)
|
||||||
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
|
req = req.WithContext(context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx))
|
||||||
|
|
||||||
expected := []byte(`{"ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"}`)
|
expected := []byte(`{"ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -860,8 +860,8 @@ func Test_caHandler_Sign(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expected1 := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
expected2 := []byte(`{"crt":"` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(stepCertPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -934,7 +934,7 @@ func Test_caHandler_Renew(t *testing.T) {
|
||||||
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -995,7 +995,7 @@ func Test_caHandler_Rekey(t *testing.T) {
|
||||||
{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},
|
{"json read error", "{", cs, nil, nil, nil, http.StatusBadRequest},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crt":"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","ca":"` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n","certChain":["` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n","` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -1210,7 +1210,7 @@ func Test_caHandler_Roots(t *testing.T) {
|
||||||
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -1256,7 +1256,7 @@ func Test_caHandler_Federation(t *testing.T) {
|
||||||
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
{"fail", cs, nil, nil, fmt.Errorf("an error"), http.StatusForbidden},
|
||||||
}
|
}
|
||||||
|
|
||||||
expected := []byte(`{"crts":["` + strings.Replace(rootPEM, "\n", `\n`, -1) + `\n"]}`)
|
expected := []byte(`{"crts":["` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|
|
@ -50,15 +50,13 @@ func WriteError(w http.ResponseWriter, err error) {
|
||||||
rl.WithFields(map[string]interface{}{
|
rl.WithFields(map[string]interface{}{
|
||||||
"stack-trace": fmt.Sprintf("%+v", e),
|
"stack-trace": fmt.Sprintf("%+v", e),
|
||||||
})
|
})
|
||||||
} else {
|
} else if e, ok := cause.(errs.StackTracer); ok {
|
||||||
if e, ok := cause.(errs.StackTracer); ok {
|
|
||||||
rl.WithFields(map[string]interface{}{
|
rl.WithFields(map[string]interface{}{
|
||||||
"stack-trace": fmt.Sprintf("%+v", e),
|
"stack-trace": fmt.Sprintf("%+v", e),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||||
LogError(w, err)
|
LogError(w, err)
|
||||||
|
|
10
api/ssh.go
10
api/ssh.go
|
@ -52,7 +52,7 @@ func (s *SSHSignRequest) Validate() error {
|
||||||
return errors.Errorf("unknown certType %s", s.CertType)
|
return errors.Errorf("unknown certType %s", s.CertType)
|
||||||
case len(s.PublicKey) == 0:
|
case len(s.PublicKey) == 0:
|
||||||
return errors.New("missing or empty publicKey")
|
return errors.New("missing or empty publicKey")
|
||||||
case len(s.OTT) == 0:
|
case s.OTT == "":
|
||||||
return errors.New("missing or empty ott")
|
return errors.New("missing or empty ott")
|
||||||
default:
|
default:
|
||||||
// Validate identity signature if provided
|
// Validate identity signature if provided
|
||||||
|
@ -408,18 +408,18 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var config SSHConfigResponse
|
var cfg SSHConfigResponse
|
||||||
switch body.Type {
|
switch body.Type {
|
||||||
case provisioner.SSHUserCert:
|
case provisioner.SSHUserCert:
|
||||||
config.UserTemplates = ts
|
cfg.UserTemplates = ts
|
||||||
case provisioner.SSHHostCert:
|
case provisioner.SSHHostCert:
|
||||||
config.HostTemplates = ts
|
cfg.HostTemplates = ts
|
||||||
default:
|
default:
|
||||||
WriteError(w, errs.InternalServer("it should hot get here"))
|
WriteError(w, errs.InternalServer("it should hot get here"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
JSON(w, config)
|
JSON(w, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||||
|
|
|
@ -19,7 +19,7 @@ type SSHRekeyRequest struct {
|
||||||
// Validate validates the SSHSignRekey.
|
// Validate validates the SSHSignRekey.
|
||||||
func (s *SSHRekeyRequest) Validate() error {
|
func (s *SSHRekeyRequest) Validate() error {
|
||||||
switch {
|
switch {
|
||||||
case len(s.OTT) == 0:
|
case s.OTT == "":
|
||||||
return errors.New("missing or empty ott")
|
return errors.New("missing or empty ott")
|
||||||
case len(s.PublicKey) == 0:
|
case len(s.PublicKey) == 0:
|
||||||
return errors.New("missing or empty public key")
|
return errors.New("missing or empty public key")
|
||||||
|
|
|
@ -18,7 +18,7 @@ type SSHRenewRequest struct {
|
||||||
// Validate validates the SSHSignRequest.
|
// Validate validates the SSHSignRequest.
|
||||||
func (s *SSHRenewRequest) Validate() error {
|
func (s *SSHRenewRequest) Validate() error {
|
||||||
switch {
|
switch {
|
||||||
case len(s.OTT) == 0:
|
case s.OTT == "":
|
||||||
return errors.New("missing or empty ott")
|
return errors.New("missing or empty ott")
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -36,7 +36,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
||||||
if !r.Passive {
|
if !r.Passive {
|
||||||
return errs.NotImplemented("non-passive revocation not implemented")
|
return errs.NotImplemented("non-passive revocation not implemented")
|
||||||
}
|
}
|
||||||
if len(r.OTT) == 0 {
|
if r.OTT == "" {
|
||||||
return errs.BadRequest("missing ott")
|
return errs.BadRequest("missing ott")
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
|
@ -284,7 +284,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
|
||||||
identityCerts := []*x509.Certificate{
|
identityCerts := []*x509.Certificate{
|
||||||
parseCertificate(certPEM),
|
parseCertificate(certPEM),
|
||||||
}
|
}
|
||||||
identityCertsPEM := []byte(`"` + strings.Replace(certPEM, "\n", `\n`, -1) + `\n"`)
|
identityCertsPEM := []byte(`"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n"`)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -27,7 +27,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
||||||
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
tok := r.Header.Get("Authorization")
|
tok := r.Header.Get("Authorization")
|
||||||
if len(tok) == 0 {
|
if tok == "" {
|
||||||
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
|
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
|
||||||
"missing authorization header token"))
|
"missing authorization header token"))
|
||||||
return
|
return
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
"github.com/smallstep/nosql/database"
|
"github.com/smallstep/nosql/database"
|
||||||
nosqldb "github.com/smallstep/nosql/database"
|
|
||||||
"go.step.sm/linkedca"
|
"go.step.sm/linkedca"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
)
|
)
|
||||||
|
@ -32,7 +31,7 @@ func TestDB_getDBAdminBytes(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -67,8 +66,8 @@ func TestDB_getDBAdminBytes(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if b, err := db.getDBAdminBytes(context.Background(), adminID); err != nil {
|
if b, err := d.getDBAdminBytes(context.Background(), adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -83,11 +82,9 @@ func TestDB_getDBAdminBytes(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.Equals(t, string(b), "foo")
|
assert.Equals(t, string(b), "foo")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -108,7 +105,7 @@ func TestDB_getDBAdmin(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -193,8 +190,8 @@ func TestDB_getDBAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if dba, err := db.getDBAdmin(context.Background(), adminID); err != nil {
|
if dba, err := d.getDBAdmin(context.Background(), adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -209,8 +206,7 @@ func TestDB_getDBAdmin(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, dba.ID, adminID)
|
assert.Equals(t, dba.ID, adminID)
|
||||||
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
|
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
|
||||||
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
|
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
|
||||||
|
@ -219,7 +215,6 @@ func TestDB_getDBAdmin(t *testing.T) {
|
||||||
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
|
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
|
||||||
assert.Fatal(t, dba.DeletedAt.IsZero())
|
assert.Fatal(t, dba.DeletedAt.IsZero())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -283,8 +278,8 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{authorityID: admin.DefaultAuthorityID}
|
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||||
if dba, err := db.unmarshalDBAdmin(tc.in, adminID); err != nil {
|
if dba, err := d.unmarshalDBAdmin(tc.in, adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -299,8 +294,7 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, dba.ID, adminID)
|
assert.Equals(t, dba.ID, adminID)
|
||||||
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
|
assert.Equals(t, dba.AuthorityID, tc.dba.AuthorityID)
|
||||||
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
|
assert.Equals(t, dba.ProvisionerID, tc.dba.ProvisionerID)
|
||||||
|
@ -309,7 +303,6 @@ func TestDB_unmarshalDBAdmin(t *testing.T) {
|
||||||
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
|
assert.Equals(t, dba.CreatedAt, tc.dba.CreatedAt)
|
||||||
assert.Fatal(t, dba.DeletedAt.IsZero())
|
assert.Fatal(t, dba.DeletedAt.IsZero())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -360,8 +353,8 @@ func TestDB_unmarshalAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{authorityID: admin.DefaultAuthorityID}
|
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||||
if adm, err := db.unmarshalAdmin(tc.in, adminID); err != nil {
|
if adm, err := d.unmarshalAdmin(tc.in, adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -376,8 +369,7 @@ func TestDB_unmarshalAdmin(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, adm.Id, adminID)
|
assert.Equals(t, adm.Id, adminID)
|
||||||
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
|
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
|
||||||
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
|
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
|
||||||
|
@ -386,7 +378,6 @@ func TestDB_unmarshalAdmin(t *testing.T) {
|
||||||
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
|
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
|
||||||
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
|
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -407,7 +398,7 @@ func TestDB_GetAdmin(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -516,8 +507,8 @@ func TestDB_GetAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if adm, err := db.GetAdmin(context.Background(), adminID); err != nil {
|
if adm, err := d.GetAdmin(context.Background(), adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -532,8 +523,7 @@ func TestDB_GetAdmin(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, adm.Id, adminID)
|
assert.Equals(t, adm.Id, adminID)
|
||||||
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
|
assert.Equals(t, adm.AuthorityId, tc.dba.AuthorityID)
|
||||||
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
|
assert.Equals(t, adm.ProvisionerId, tc.dba.ProvisionerID)
|
||||||
|
@ -542,7 +532,6 @@ func TestDB_GetAdmin(t *testing.T) {
|
||||||
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
|
assert.Equals(t, adm.CreatedAt, timestamppb.New(tc.dba.CreatedAt))
|
||||||
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
|
assert.Equals(t, adm.DeletedAt, timestamppb.New(tc.dba.DeletedAt))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -562,7 +551,7 @@ func TestDB_DeleteAdmin(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -670,8 +659,8 @@ func TestDB_DeleteAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.DeleteAdmin(context.Background(), adminID); err != nil {
|
if err := d.DeleteAdmin(context.Background(), adminID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -708,7 +697,7 @@ func TestDB_UpdateAdmin(t *testing.T) {
|
||||||
assert.Equals(t, bucket, adminsTable)
|
assert.Equals(t, bucket, adminsTable)
|
||||||
assert.Equals(t, string(key), adminID)
|
assert.Equals(t, string(key), adminID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "admin adminID not found"),
|
||||||
|
@ -821,8 +810,8 @@ func TestDB_UpdateAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.UpdateAdmin(context.Background(), tc.adm); err != nil {
|
if err := d.UpdateAdmin(context.Background(), tc.adm); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -919,8 +908,8 @@ func TestDB_CreateAdmin(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.CreateAdmin(context.Background(), tc.adm); err != nil {
|
if err := d.CreateAdmin(context.Background(), tc.adm); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -1095,8 +1084,8 @@ func TestDB_GetAdmins(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if admins, err := db.GetAdmins(context.Background()); err != nil {
|
if admins, err := d.GetAdmins(context.Background()); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -1111,11 +1100,9 @@ func TestDB_GetAdmins(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
tc.verify(t, admins)
|
tc.verify(t, admins)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ func New(db nosqlDB.DB, authorityID string) (*DB, error) {
|
||||||
|
|
||||||
// save writes the new data to the database, overwriting the old data if it
|
// save writes the new data to the database, overwriting the old data if it
|
||||||
// existed.
|
// existed.
|
||||||
func (db *DB) save(ctx context.Context, id string, nu interface{}, old interface{}, typ string, table []byte) error {
|
func (db *DB) save(ctx context.Context, id string, nu, old interface{}, typ string, table []byte) error {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
newB []byte
|
newB []byte
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"github.com/smallstep/certificates/db"
|
"github.com/smallstep/certificates/db"
|
||||||
"github.com/smallstep/nosql"
|
"github.com/smallstep/nosql"
|
||||||
"github.com/smallstep/nosql/database"
|
"github.com/smallstep/nosql/database"
|
||||||
nosqldb "github.com/smallstep/nosql/database"
|
|
||||||
"go.step.sm/linkedca"
|
"go.step.sm/linkedca"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -31,7 +30,7 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -66,8 +65,8 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db}
|
d := DB{db: tc.db}
|
||||||
if b, err := db.getDBProvisionerBytes(context.Background(), provID); err != nil {
|
if b, err := d.getDBProvisionerBytes(context.Background(), provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -82,11 +81,9 @@ func TestDB_getDBProvisionerBytes(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, string(b), "foo")
|
assert.Equals(t, string(b), "foo")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,7 +104,7 @@ func TestDB_getDBProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -190,8 +187,8 @@ func TestDB_getDBProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if dbp, err := db.getDBProvisioner(context.Background(), provID); err != nil {
|
if dbp, err := d.getDBProvisioner(context.Background(), provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -206,8 +203,7 @@ func TestDB_getDBProvisioner(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, dbp.ID, provID)
|
assert.Equals(t, dbp.ID, provID)
|
||||||
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
|
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
|
||||||
assert.Equals(t, dbp.Type, tc.dbp.Type)
|
assert.Equals(t, dbp.Type, tc.dbp.Type)
|
||||||
|
@ -215,7 +211,6 @@ func TestDB_getDBProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
|
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
|
||||||
assert.Fatal(t, dbp.DeletedAt.IsZero())
|
assert.Fatal(t, dbp.DeletedAt.IsZero())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -278,8 +273,8 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{authorityID: admin.DefaultAuthorityID}
|
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||||
if dbp, err := db.unmarshalDBProvisioner(tc.in, provID); err != nil {
|
if dbp, err := d.unmarshalDBProvisioner(tc.in, provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -294,8 +289,7 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, dbp.ID, provID)
|
assert.Equals(t, dbp.ID, provID)
|
||||||
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
|
assert.Equals(t, dbp.AuthorityID, tc.dbp.AuthorityID)
|
||||||
assert.Equals(t, dbp.Type, tc.dbp.Type)
|
assert.Equals(t, dbp.Type, tc.dbp.Type)
|
||||||
|
@ -307,7 +301,6 @@ func TestDB_unmarshalDBProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
|
assert.Equals(t, dbp.CreatedAt, tc.dbp.CreatedAt)
|
||||||
assert.Fatal(t, dbp.DeletedAt.IsZero())
|
assert.Fatal(t, dbp.DeletedAt.IsZero())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -402,8 +395,8 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{authorityID: admin.DefaultAuthorityID}
|
d := DB{authorityID: admin.DefaultAuthorityID}
|
||||||
if prov, err := db.unmarshalProvisioner(tc.in, provID); err != nil {
|
if prov, err := d.unmarshalProvisioner(tc.in, provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -418,8 +411,7 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, prov.Id, provID)
|
assert.Equals(t, prov.Id, provID)
|
||||||
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
|
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
|
||||||
assert.Equals(t, prov.Type, tc.dbp.Type)
|
assert.Equals(t, prov.Type, tc.dbp.Type)
|
||||||
|
@ -432,7 +424,6 @@ func TestDB_unmarshalProvisioner(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, retDetailsBytes, tc.dbp.Details)
|
assert.Equals(t, retDetailsBytes, tc.dbp.Details)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -453,7 +444,7 @@ func TestDB_GetProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -542,8 +533,8 @@ func TestDB_GetProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if prov, err := db.GetProvisioner(context.Background(), provID); err != nil {
|
if prov, err := d.GetProvisioner(context.Background(), provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -558,8 +549,7 @@ func TestDB_GetProvisioner(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
assert.Equals(t, prov.Id, provID)
|
assert.Equals(t, prov.Id, provID)
|
||||||
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
|
assert.Equals(t, prov.AuthorityId, tc.dbp.AuthorityID)
|
||||||
assert.Equals(t, prov.Type, tc.dbp.Type)
|
assert.Equals(t, prov.Type, tc.dbp.Type)
|
||||||
|
@ -572,7 +562,6 @@ func TestDB_GetProvisioner(t *testing.T) {
|
||||||
assert.FatalError(t, err)
|
assert.FatalError(t, err)
|
||||||
assert.Equals(t, retDetailsBytes, tc.dbp.Details)
|
assert.Equals(t, retDetailsBytes, tc.dbp.Details)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -592,7 +581,7 @@ func TestDB_DeleteProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -692,8 +681,8 @@ func TestDB_DeleteProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.DeleteProvisioner(context.Background(), provID); err != nil {
|
if err := d.DeleteProvisioner(context.Background(), provID); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -853,8 +842,8 @@ func TestDB_GetProvisioners(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if provs, err := db.GetProvisioners(context.Background()); err != nil {
|
if provs, err := d.GetProvisioners(context.Background()); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -869,11 +858,9 @@ func TestDB_GetProvisioners(t *testing.T) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
||||||
if assert.Nil(t, tc.err) && assert.Nil(t, tc.adminErr) {
|
|
||||||
tc.verify(t, provs)
|
tc.verify(t, provs)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -963,8 +950,8 @@ func TestDB_CreateProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.CreateProvisioner(context.Background(), tc.prov); err != nil {
|
if err := d.CreateProvisioner(context.Background(), tc.prov); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
@ -1001,7 +988,7 @@ func TestDB_UpdateProvisioner(t *testing.T) {
|
||||||
assert.Equals(t, bucket, provisionersTable)
|
assert.Equals(t, bucket, provisionersTable)
|
||||||
assert.Equals(t, string(key), provID)
|
assert.Equals(t, string(key), provID)
|
||||||
|
|
||||||
return nil, nosqldb.ErrNotFound
|
return nil, database.ErrNotFound
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
adminErr: admin.NewError(admin.ErrorNotFoundType, "provisioner provID not found"),
|
||||||
|
@ -1199,8 +1186,8 @@ func TestDB_UpdateProvisioner(t *testing.T) {
|
||||||
for name, run := range tests {
|
for name, run := range tests {
|
||||||
tc := run(t)
|
tc := run(t)
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
db := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
d := DB{db: tc.db, authorityID: admin.DefaultAuthorityID}
|
||||||
if err := db.UpdateProvisioner(context.Background(), tc.prov); err != nil {
|
if err := d.UpdateProvisioner(context.Background(), tc.prov); err != nil {
|
||||||
switch k := err.(type) {
|
switch k := err.(type) {
|
||||||
case *admin.Error:
|
case *admin.Error:
|
||||||
if assert.NotNil(t, tc.adminErr) {
|
if assert.NotNil(t, tc.adminErr) {
|
||||||
|
|
|
@ -55,8 +55,8 @@ type subProv struct {
|
||||||
provisioner string
|
provisioner string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSubProv(subject, provisioner string) subProv {
|
func newSubProv(subject, prov string) subProv {
|
||||||
return subProv{subject, provisioner}
|
return subProv{subject, prov}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadBySubProv a admin by the subject and provisioner name.
|
// LoadBySubProv a admin by the subject and provisioner name.
|
||||||
|
|
|
@ -16,10 +16,10 @@ func (a *Authority) LoadAdminByID(id string) (*linkedca.Admin, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadAdminBySubProv returns an *linkedca.Admin with the given ID.
|
// LoadAdminBySubProv returns an *linkedca.Admin with the given ID.
|
||||||
func (a *Authority) LoadAdminBySubProv(subject, provisioner string) (*linkedca.Admin, bool) {
|
func (a *Authority) LoadAdminBySubProv(subject, prov string) (*linkedca.Admin, bool) {
|
||||||
a.adminMutex.RLock()
|
a.adminMutex.RLock()
|
||||||
defer a.adminMutex.RUnlock()
|
defer a.adminMutex.RUnlock()
|
||||||
return a.admins.LoadBySubProv(subject, provisioner)
|
return a.admins.LoadBySubProv(subject, prov)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdmins returns a map listing each provisioner and the JWK Key Set
|
// GetAdmins returns a map listing each provisioner and the JWK Key Set
|
||||||
|
|
|
@ -78,14 +78,14 @@ type Authority struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates and initiates a new Authority type.
|
// New creates and initiates a new Authority type.
|
||||||
func New(config *config.Config, opts ...Option) (*Authority, error) {
|
func New(cfg *config.Config, opts ...Option) (*Authority, error) {
|
||||||
err := config.Validate()
|
err := cfg.Validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var a = &Authority{
|
var a = &Authority{
|
||||||
config: config,
|
config: cfg,
|
||||||
certificates: new(sync.Map),
|
certificates: new(sync.Map),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -54,7 +54,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
||||||
// key in order to verify the claims and we need the issuer from the claims
|
// key in order to verify the claims and we need the issuer from the claims
|
||||||
// before we can look up the provisioner.
|
// before we can look up the provisioner.
|
||||||
var claims Claims
|
var claims Claims
|
||||||
if err = tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
|
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
|
||||||
// Store the token to protect against reuse unless it's skipped.
|
// Store the token to protect against reuse unless it's skipped.
|
||||||
// If we cannot get a token id from the provisioner, just hash the token.
|
// If we cannot get a token id from the provisioner, just hash the token.
|
||||||
if !SkipTokenReuseFromContext(ctx) {
|
if !SkipTokenReuseFromContext(ctx) {
|
||||||
if err = a.UseToken(token, p); err != nil {
|
if err := a.UseToken(token, p); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -112,7 +112,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
||||||
// to the public certificate in the `x5c` header of the token.
|
// to the public certificate in the `x5c` header of the token.
|
||||||
// 2. Asserts that the claims are valid - have not been tampered with.
|
// 2. Asserts that the claims are valid - have not been tampered with.
|
||||||
var claims jose.Claims
|
var claims jose.Claims
|
||||||
if err = jwt.Claims(leaf.PublicKey, &claims); err != nil {
|
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
|
||||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims")
|
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error parsing x5c claims")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,13 +122,13 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that the token has not been used.
|
// Check that the token has not been used.
|
||||||
if err = a.UseToken(token, prov); err != nil {
|
if err := a.UseToken(token, prov); err != nil {
|
||||||
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
|
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "adminHandler.authorizeToken; error with reuse token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
// According to "rfc7519 JSON Web Token" acceptable skew should be no
|
||||||
// more than a few minutes.
|
// more than a few minutes.
|
||||||
if err = claims.ValidateWithLeeway(jose.Expected{
|
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||||
Issuer: prov.GetName(),
|
Issuer: prov.GetName(),
|
||||||
Time: time.Now().UTC(),
|
Time: time.Now().UTC(),
|
||||||
}, time.Minute); err != nil {
|
}, time.Minute); err != nil {
|
||||||
|
@ -262,7 +262,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
||||||
}
|
}
|
||||||
if err = p.AuthorizeRevoke(ctx, token); err != nil {
|
if err := p.AuthorizeRevoke(ctx, token); err != nil {
|
||||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRevoke")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -917,7 +917,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if err = cert.SignCert(rand.Reader, signer); err != nil {
|
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return cert, jwk, nil
|
return cert, jwk, nil
|
||||||
|
|
|
@ -25,7 +25,7 @@ func (s multiString) HasEmpties() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
for _, ss := range s {
|
for _, ss := range s {
|
||||||
if len(ss) == 0 {
|
if ss == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -272,12 +272,12 @@ func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificat
|
||||||
return errors.Wrap(err, "error revoking certificate")
|
return errors.Wrap(err, "error revoking certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *linkedCaClient) RevokeSSH(ssh *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
|
func (c *linkedCaClient) RevokeSSH(cert *ssh.Certificate, rci *db.RevokedCertificateInfo) error {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
_, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{
|
_, err := c.client.RevokeSSHCertificate(ctx, &linkedca.RevokeSSHCertificateRequest{
|
||||||
Serial: rci.Serial,
|
Serial: rci.Serial,
|
||||||
Certificate: serializeSSHCertificate(ssh),
|
Certificate: serializeSSHCertificate(cert),
|
||||||
Reason: rci.Reason,
|
Reason: rci.Reason,
|
||||||
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
|
ReasonCode: linkedca.RevocationReasonCode(rci.ReasonCode),
|
||||||
Passive: true,
|
Passive: true,
|
||||||
|
|
|
@ -22,9 +22,9 @@ type Option func(*Authority) error
|
||||||
|
|
||||||
// WithConfig replaces the current config with the given one. No validation is
|
// WithConfig replaces the current config with the given one. No validation is
|
||||||
// performed in the given value.
|
// performed in the given value.
|
||||||
func WithConfig(config *config.Config) Option {
|
func WithConfig(cfg *config.Config) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.config = config
|
a.config = cfg
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -76,9 +76,9 @@ func WithIssuerPassword(password []byte) Option {
|
||||||
|
|
||||||
// WithDatabase sets an already initialized authority database to a new
|
// WithDatabase sets an already initialized authority database to a new
|
||||||
// authority. This option is intended to be use on graceful reloads.
|
// authority. This option is intended to be use on graceful reloads.
|
||||||
func WithDatabase(db db.AuthDB) Option {
|
func WithDatabase(d db.AuthDB) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.db = db
|
a.db = d
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -225,9 +225,9 @@ func WithX509FederatedBundle(pemCerts []byte) Option {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithAdminDB is an option to set the database backing the admin APIs.
|
// WithAdminDB is an option to set the database backing the admin APIs.
|
||||||
func WithAdminDB(db admin.DB) Option {
|
func WithAdminDB(d admin.DB) Option {
|
||||||
return func(a *Authority) error {
|
return func(a *Authority) error {
|
||||||
a.adminDB = db
|
a.adminDB = d
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -312,7 +312,7 @@ func (p *AWS) GetType() Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEncryptedKey is not available in an AWS provisioner.
|
// GetEncryptedKey is not available in an AWS provisioner.
|
||||||
func (p *AWS) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (p *AWS) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -449,13 +449,15 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
// There's no way to trust them other than TOFU.
|
// There's no way to trust them other than TOFU.
|
||||||
var so []SignOption
|
var so []SignOption
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region)
|
dnsName := fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region)
|
||||||
so = append(so, dnsNamesValidator([]string{dnsName}))
|
so = append(so,
|
||||||
so = append(so, ipAddressesValidator([]net.IP{
|
dnsNamesValidator([]string{dnsName}),
|
||||||
|
ipAddressesValidator([]net.IP{
|
||||||
net.ParseIP(doc.PrivateIP),
|
net.ParseIP(doc.PrivateIP),
|
||||||
}))
|
}),
|
||||||
so = append(so, emailAddressesValidator(nil))
|
emailAddressesValidator(nil),
|
||||||
so = append(so, urisValidator(nil))
|
urisValidator(nil),
|
||||||
|
)
|
||||||
|
|
||||||
// Template options
|
// Template options
|
||||||
data.SetSANs([]string{dnsName, doc.PrivateIP})
|
data.SetSANs([]string{dnsName, doc.PrivateIP})
|
||||||
|
@ -669,7 +671,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
if payload.Subject != doc.InstanceID &&
|
if payload.Subject != doc.InstanceID &&
|
||||||
payload.Subject != doc.PrivateIP &&
|
payload.Subject != doc.PrivateIP &&
|
||||||
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region) {
|
payload.Subject != fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region) {
|
||||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
|
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid subject claim (sub)")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -720,7 +722,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
||||||
// Validated principals.
|
// Validated principals.
|
||||||
principals := []string{
|
principals := []string{
|
||||||
doc.PrivateIP,
|
doc.PrivateIP,
|
||||||
fmt.Sprintf("ip-%s.%s.compute.internal", strings.Replace(doc.PrivateIP, ".", "-", -1), doc.Region),
|
fmt.Sprintf("ip-%s.%s.compute.internal", strings.ReplaceAll(doc.PrivateIP, ".", "-"), doc.Region),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only enforce known principals if disable custom sans is true.
|
// Only enforce known principals if disable custom sans is true.
|
||||||
|
|
|
@ -663,15 +663,15 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
got, err := tt.aws.AuthorizeSign(ctx, tt.args.token)
|
switch got, err := tt.aws.AuthorizeSign(ctx, tt.args.token); {
|
||||||
if (err != nil) != tt.wantErr {
|
case (err != nil) != tt.wantErr:
|
||||||
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AWS.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(errs.StatusCoder)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
} else {
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
|
|
@ -152,7 +152,7 @@ func (p *Azure) GetType() Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEncryptedKey is not available in an Azure provisioner.
|
// GetEncryptedKey is not available in an Azure provisioner.
|
||||||
func (p *Azure) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (p *Azure) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -303,11 +303,13 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
||||||
var so []SignOption
|
var so []SignOption
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
// name will work only inside the virtual network
|
// name will work only inside the virtual network
|
||||||
so = append(so, commonNameValidator(name))
|
so = append(so,
|
||||||
so = append(so, dnsNamesValidator([]string{name}))
|
commonNameValidator(name),
|
||||||
so = append(so, ipAddressesValidator(nil))
|
dnsNamesValidator([]string{name}),
|
||||||
so = append(so, emailAddressesValidator(nil))
|
ipAddressesValidator(nil),
|
||||||
so = append(so, urisValidator(nil))
|
emailAddressesValidator(nil),
|
||||||
|
urisValidator(nil),
|
||||||
|
)
|
||||||
|
|
||||||
// Enforce SANs in the template.
|
// Enforce SANs in the template.
|
||||||
data.SetSANs([]string{name})
|
data.SetSANs([]string{name})
|
||||||
|
|
|
@ -446,15 +446,15 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
got, err := tt.azure.AuthorizeSign(ctx, tt.args.token)
|
switch got, err := tt.azure.AuthorizeSign(ctx, tt.args.token); {
|
||||||
if (err != nil) != tt.wantErr {
|
case (err != nil) != tt.wantErr:
|
||||||
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Azure.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(errs.StatusCoder)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
} else {
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
|
|
@ -229,7 +229,9 @@ func (c *Collection) Remove(id string) error {
|
||||||
|
|
||||||
var found bool
|
var found bool
|
||||||
for i, elem := range c.sorted {
|
for i, elem := range c.sorted {
|
||||||
if elem.provisioner.GetID() == id {
|
if elem.provisioner.GetID() != id {
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Remove index in sorted list
|
// Remove index in sorted list
|
||||||
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
|
copy(c.sorted[i:], c.sorted[i+1:]) // Shift a[i+1:] left one index.
|
||||||
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
|
c.sorted[len(c.sorted)-1] = uidProvisioner{} // Erase last element (write zero value).
|
||||||
|
@ -237,7 +239,6 @@ func (c *Collection) Remove(id string) error {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if !found {
|
if !found {
|
||||||
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName())
|
return admin.NewError(admin.ErrorNotFoundType, "provisioner %s not found in sorted list", prov.GetName())
|
||||||
}
|
}
|
||||||
|
|
|
@ -150,7 +150,7 @@ func (p *GCP) GetType() Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEncryptedKey is not available in a GCP provisioner.
|
// GetEncryptedKey is not available in a GCP provisioner.
|
||||||
func (p *GCP) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (p *GCP) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,15 +244,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
||||||
if p.DisableCustomSANs {
|
if p.DisableCustomSANs {
|
||||||
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
|
dnsName1 := fmt.Sprintf("%s.c.%s.internal", ce.InstanceName, ce.ProjectID)
|
||||||
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
|
dnsName2 := fmt.Sprintf("%s.%s.c.%s.internal", ce.InstanceName, ce.Zone, ce.ProjectID)
|
||||||
so = append(so, commonNameSliceValidator([]string{
|
so = append(so,
|
||||||
|
commonNameSliceValidator([]string{
|
||||||
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
|
ce.InstanceName, ce.InstanceID, dnsName1, dnsName2,
|
||||||
}))
|
}),
|
||||||
so = append(so, dnsNamesValidator([]string{
|
dnsNamesValidator([]string{
|
||||||
dnsName1, dnsName2,
|
dnsName1, dnsName2,
|
||||||
}))
|
}),
|
||||||
so = append(so, ipAddressesValidator(nil))
|
ipAddressesValidator(nil),
|
||||||
so = append(so, emailAddressesValidator(nil))
|
emailAddressesValidator(nil),
|
||||||
so = append(so, urisValidator(nil))
|
urisValidator(nil),
|
||||||
|
)
|
||||||
|
|
||||||
// Template SANs
|
// Template SANs
|
||||||
data.SetSANs([]string{dnsName1, dnsName2})
|
data.SetSANs([]string{dnsName1, dnsName2})
|
||||||
|
|
|
@ -535,15 +535,15 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
ctx := NewContextWithMethod(context.Background(), SignMethod)
|
||||||
got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token)
|
switch got, err := tt.gcp.AuthorizeSign(ctx, tt.args.token); {
|
||||||
if (err != nil) != tt.wantErr {
|
case (err != nil) != tt.wantErr:
|
||||||
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("GCP.AuthorizeSign() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
case err != nil:
|
||||||
sc, ok := err.(errs.StatusCoder)
|
sc, ok := err.(errs.StatusCoder)
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
} else {
|
default:
|
||||||
assert.Len(t, tt.wantLen, got)
|
assert.Len(t, tt.wantLen, got)
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
|
|
|
@ -18,7 +18,7 @@ const (
|
||||||
defaultCacheJitter = 1 * time.Hour
|
defaultCacheJitter = 1 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
var maxAgeRegex = regexp.MustCompile("max-age=([0-9]+)")
|
var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)
|
||||||
|
|
||||||
type keyStore struct {
|
type keyStore struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|
|
@ -29,7 +29,7 @@ func (p *noop) GetType() Type {
|
||||||
return noopType
|
return noopType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *noop) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (p *noop) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -148,7 +148,7 @@ func (o *OIDC) GetType() Type {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEncryptedKey is not available in an OIDC provisioner.
|
// GetEncryptedKey is not available in an OIDC provisioner.
|
||||||
func (o *OIDC) GetEncryptedKey() (kid string, key string, ok bool) {
|
func (o *OIDC) GetEncryptedKey() (kid, key string, ok bool) {
|
||||||
return "", "", false
|
return "", "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,7 +193,7 @@ func (o *OIDC) Init(config Config) (err error) {
|
||||||
}
|
}
|
||||||
// Replace {tenantid} with the configured one
|
// Replace {tenantid} with the configured one
|
||||||
if o.TenantID != "" {
|
if o.TenantID != "" {
|
||||||
o.configuration.Issuer = strings.Replace(o.configuration.Issuer, "{tenantid}", o.TenantID, -1)
|
o.configuration.Issuer = strings.ReplaceAll(o.configuration.Issuer, "{tenantid}", o.TenantID)
|
||||||
}
|
}
|
||||||
// Get JWK key set
|
// Get JWK key set
|
||||||
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
|
o.keyStore, err = newKeyStore(o.configuration.JWKSetURI)
|
||||||
|
|
|
@ -321,13 +321,8 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
||||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||||
assert.Equals(t, sc.StatusCode(), tt.code)
|
assert.Equals(t, sc.StatusCode(), tt.code)
|
||||||
assert.Nil(t, got)
|
assert.Nil(t, got)
|
||||||
} else {
|
} else if assert.NotNil(t, got) {
|
||||||
if assert.NotNil(t, got) {
|
|
||||||
if tt.name == "admin" {
|
|
||||||
assert.Len(t, 5, got)
|
assert.Len(t, 5, got)
|
||||||
} else {
|
|
||||||
assert.Len(t, 5, got)
|
|
||||||
}
|
|
||||||
for _, o := range got {
|
for _, o := range got {
|
||||||
switch v := o.(type) {
|
switch v := o.(type) {
|
||||||
case certificateOptionsFunc:
|
case certificateOptionsFunc:
|
||||||
|
@ -349,7 +344,6 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,7 +138,7 @@ func unsafeParseSigned(s string) (map[string]interface{}, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
claims := make(map[string]interface{})
|
claims := make(map[string]interface{})
|
||||||
if err = token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
if err := token.UnsafeClaimsWithoutVerification(&claims); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return claims, nil
|
return claims, nil
|
||||||
|
|
|
@ -123,7 +123,7 @@ func (a Audiences) WithFragment(fragment string) Audiences {
|
||||||
|
|
||||||
// generateSignAudience generates a sign audience with the format
|
// generateSignAudience generates a sign audience with the format
|
||||||
// https://<host>/1.0/sign#provisionerID
|
// https://<host>/1.0/sign#provisionerID
|
||||||
func generateSignAudience(caURL string, provisionerID string) (string, error) {
|
func generateSignAudience(caURL, provisionerID string) (string, error) {
|
||||||
u, err := url.Parse(caURL)
|
u, err := url.Parse(caURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrapf(err, "error parsing %s", caURL)
|
return "", errors.Wrapf(err, "error parsing %s", caURL)
|
||||||
|
|
|
@ -44,7 +44,7 @@ func TestSSHOptions_Modify(t *testing.T) {
|
||||||
valid func(*ssh.Certificate)
|
valid func(*ssh.Certificate)
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"fail/unexpected-cert-type": func() test {
|
"fail/unexpected-cert-type": func() test {
|
||||||
return test{
|
return test{
|
||||||
so: SignSSHOptions{CertType: "foo"},
|
so: SignSSHOptions{CertType: "foo"},
|
||||||
|
@ -117,7 +117,7 @@ func TestSSHOptions_Match(t *testing.T) {
|
||||||
cmp SignSSHOptions
|
cmp SignSSHOptions
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"fail/cert-type": func() test {
|
"fail/cert-type": func() test {
|
||||||
return test{
|
return test{
|
||||||
so: SignSSHOptions{CertType: "foo"},
|
so: SignSSHOptions{CertType: "foo"},
|
||||||
|
@ -208,7 +208,7 @@ func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
expected []string
|
expected []string
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok": func() test {
|
"ok": func() test {
|
||||||
a := []string{"foo", "bar"}
|
a := []string{"foo", "bar"}
|
||||||
return test{
|
return test{
|
||||||
|
@ -234,7 +234,7 @@ func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
expected string
|
expected string
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok": func() test {
|
"ok": func() test {
|
||||||
a := "foo"
|
a := "foo"
|
||||||
return test{
|
return test{
|
||||||
|
@ -260,7 +260,7 @@ func Test_sshCertTypeModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
expected uint32
|
expected uint32
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok/user": func() test {
|
"ok/user": func() test {
|
||||||
return test{
|
return test{
|
||||||
modifier: sshCertTypeModifier("user"),
|
modifier: sshCertTypeModifier("user"),
|
||||||
|
@ -299,7 +299,7 @@ func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
expected uint64
|
expected uint64
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok": func() test {
|
"ok": func() test {
|
||||||
return test{
|
return test{
|
||||||
modifier: sshCertValidAfterModifier(15),
|
modifier: sshCertValidAfterModifier(15),
|
||||||
|
@ -324,7 +324,7 @@ func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
|
||||||
cert *ssh.Certificate
|
cert *ssh.Certificate
|
||||||
valid func(*ssh.Certificate)
|
valid func(*ssh.Certificate)
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"ok/changes": func() test {
|
"ok/changes": func() test {
|
||||||
n := time.Now()
|
n := time.Now()
|
||||||
va := NewTimeDuration(n.Add(1 * time.Minute))
|
va := NewTimeDuration(n.Add(1 * time.Minute))
|
||||||
|
@ -388,7 +388,7 @@ func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
|
||||||
valid func(*ssh.Certificate)
|
valid func(*ssh.Certificate)
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
tests := map[string](func() test){
|
tests := map[string]func() test{
|
||||||
"fail/unexpected-cert-type": func() test {
|
"fail/unexpected-cert-type": func() test {
|
||||||
cert := &ssh.Certificate{CertType: 3}
|
cert := &ssh.Certificate{CertType: 3}
|
||||||
return test{
|
return test{
|
||||||
|
|
|
@ -46,7 +46,7 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
if err = cert.SignCert(rand.Reader, signer); err != nil {
|
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
return cert, jwk, nil
|
return cert, jwk, nil
|
||||||
|
@ -214,11 +214,9 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
||||||
if assert.NotNil(t, tc.err) {
|
if assert.NotNil(t, tc.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.err.Error())
|
||||||
}
|
}
|
||||||
} else {
|
} else if assert.Nil(t, tc.err) {
|
||||||
if assert.Nil(t, tc.err) {
|
|
||||||
assert.NotNil(t, claims)
|
assert.NotNil(t, claims)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -732,7 +732,7 @@ func withSSHPOPFile(cert *ssh.Certificate) tokOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateToken(sub, iss, aud string, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
func generateToken(sub, iss, aud, email string, sans []string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
||||||
so := new(jose.SignerOptions)
|
so := new(jose.SignerOptions)
|
||||||
so.WithType("JWT")
|
so.WithType("JWT")
|
||||||
so.WithHeader("kid", jwk.KeyID)
|
so.WithHeader("kid", jwk.KeyID)
|
||||||
|
@ -773,7 +773,7 @@ func generateToken(sub, iss, aud string, email string, sans []string, iat time.T
|
||||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateOIDCToken(sub, iss, aud string, email string, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
func generateOIDCToken(sub, iss, aud, email, preferredUsername string, iat time.Time, jwk *jose.JSONWebKey, tokOpts ...tokOption) (string, error) {
|
||||||
so := new(jose.SignerOptions)
|
so := new(jose.SignerOptions)
|
||||||
so.WithType("JWT")
|
so.WithType("JWT")
|
||||||
so.WithHeader("kid", jwk.KeyID)
|
so.WithHeader("kid", jwk.KeyID)
|
||||||
|
|
|
@ -108,7 +108,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
|
||||||
|
|
||||||
// GetSSHBastion returns the bastion configuration, for the given pair user,
|
// GetSSHBastion returns the bastion configuration, for the given pair user,
|
||||||
// hostname.
|
// hostname.
|
||||||
func (a *Authority) GetSSHBastion(ctx context.Context, user string, hostname string) (*config.Bastion, error) {
|
func (a *Authority) GetSSHBastion(ctx context.Context, user, hostname string) (*config.Bastion, error) {
|
||||||
if a.sshBastionFunc != nil {
|
if a.sshBastionFunc != nil {
|
||||||
bs, err := a.sshBastionFunc(ctx, user, hostname)
|
bs, err := a.sshBastionFunc(ctx, user, hostname)
|
||||||
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
|
return bs, errs.Wrap(http.StatusInternalServerError, err, "authority.GetSSHBastion")
|
||||||
|
@ -477,7 +477,7 @@ func (a *Authority) SignSSHAddUser(ctx context.Context, key ssh.PublicKey, subje
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckSSHHost checks the given principal has been registered before.
|
// CheckSSHHost checks the given principal has been registered before.
|
||||||
func (a *Authority) CheckSSHHost(ctx context.Context, principal string, token string) (bool, error) {
|
func (a *Authority) CheckSSHHost(ctx context.Context, principal, token string) (bool, error) {
|
||||||
if a.sshCheckHostFunc != nil {
|
if a.sshCheckHostFunc != nil {
|
||||||
exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates())
|
exists, err := a.sshCheckHostFunc(ctx, principal, token, a.GetRootCertificates())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -531,5 +531,5 @@ func (a *Authority) getAddUserCommand(principal string) string {
|
||||||
} else {
|
} else {
|
||||||
cmd = a.config.SSH.AddUserCommand
|
cmd = a.config.SSH.AddUserCommand
|
||||||
}
|
}
|
||||||
return strings.Replace(cmd, "<principal>", principal, -1)
|
return strings.ReplaceAll(cmd, "<principal>", principal)
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,10 +55,10 @@ func withDefaultASN1DN(def *config.ASN1DN) provisioner.CertificateModifierFunc {
|
||||||
if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" {
|
if len(crt.Subject.StreetAddress) == 0 && def.StreetAddress != "" {
|
||||||
crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress)
|
crt.Subject.StreetAddress = append(crt.Subject.StreetAddress, def.StreetAddress)
|
||||||
}
|
}
|
||||||
if len(crt.Subject.SerialNumber) == 0 && def.SerialNumber != "" {
|
if crt.Subject.SerialNumber == "" && def.SerialNumber != "" {
|
||||||
crt.Subject.SerialNumber = def.SerialNumber
|
crt.Subject.SerialNumber = def.SerialNumber
|
||||||
}
|
}
|
||||||
if len(crt.Subject.CommonName) == 0 && def.CommonName != "" {
|
if crt.Subject.CommonName == "" && def.CommonName != "" {
|
||||||
crt.Subject.CommonName = def.CommonName
|
crt.Subject.CommonName = def.CommonName
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -387,15 +387,15 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
|
||||||
return errs.Wrap(http.StatusInternalServerError, err,
|
return errs.Wrap(http.StatusInternalServerError, err,
|
||||||
"authority.Revoke; could not get ID for token")
|
"authority.Revoke; could not get ID for token")
|
||||||
}
|
}
|
||||||
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
|
opts = append(opts,
|
||||||
opts = append(opts, errs.WithKeyVal("tokenID", rci.TokenID))
|
errs.WithKeyVal("provisionerID", rci.ProvisionerID),
|
||||||
} else {
|
errs.WithKeyVal("tokenID", rci.TokenID),
|
||||||
|
)
|
||||||
|
} else if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil {
|
||||||
// Load the Certificate provisioner if one exists.
|
// Load the Certificate provisioner if one exists.
|
||||||
if p, err = a.LoadProvisionerByCertificate(revokeOpts.Crt); err == nil {
|
|
||||||
rci.ProvisionerID = p.GetID()
|
rci.ProvisionerID = p.GetID()
|
||||||
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
|
opts = append(opts, errs.WithKeyVal("provisionerID", rci.ProvisionerID))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {
|
if provisioner.MethodFromContext(ctx) == provisioner.SSHRevokeMethod {
|
||||||
err = a.revokeSSH(nil, rci)
|
err = a.revokeSSH(nil, rci)
|
||||||
|
|
|
@ -426,6 +426,7 @@ ZYtQ9Ot36qc=
|
||||||
{Id: stepOIDProvisioner, Value: []byte("foo")},
|
{Id: stepOIDProvisioner, Value: []byte("foo")},
|
||||||
{Id: []int{1, 1, 1}, Value: []byte("bar")}}))
|
{Id: []int{1, 1, 1}, Value: []byte("bar")}}))
|
||||||
now := time.Now().UTC()
|
now := time.Now().UTC()
|
||||||
|
// nolint:gocritic
|
||||||
enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
|
enforcedExtraOptions := append(extraOpts, &certificateDurationEnforcer{
|
||||||
NotBefore: now,
|
NotBefore: now,
|
||||||
NotAfter: now.Add(365 * 24 * time.Hour),
|
NotAfter: now.Add(365 * 24 * time.Hour),
|
||||||
|
|
|
@ -345,7 +345,7 @@ func readACMEError(r io.ReadCloser) error {
|
||||||
ae := new(acme.Error)
|
ae := new(acme.Error)
|
||||||
err = json.Unmarshal(b, &ae)
|
err = json.Unmarshal(b, &ae)
|
||||||
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
// If we successfully marshaled to an ACMEError then return the ACMEError.
|
||||||
if err != nil || len(ae.Error()) == 0 {
|
if err != nil || ae.Error() == "" {
|
||||||
fmt.Printf("b = %s\n", b)
|
fmt.Printf("b = %s\n", b)
|
||||||
// Throw up our hands.
|
// Throw up our hands.
|
||||||
return errors.Errorf("%s", b)
|
return errors.Errorf("%s", b)
|
||||||
|
|
|
@ -1247,6 +1247,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
|
||||||
Type: "Certificate",
|
Type: "Certificate",
|
||||||
Bytes: leaf.Raw,
|
Bytes: leaf.Raw,
|
||||||
})
|
})
|
||||||
|
// nolint:gocritic
|
||||||
certBytes := append(leafb, leafb...)
|
certBytes := append(leafb, leafb...)
|
||||||
certBytes = append(certBytes, leafb...)
|
certBytes = append(certBytes, leafb...)
|
||||||
ac := &ACMEClient{
|
ac := &ACMEClient{
|
||||||
|
|
|
@ -70,7 +70,7 @@ func NewAdminClient(endpoint string, opts ...ClientOption) (*AdminClient, error)
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AdminClient) generateAdminToken(path string) (string, error) {
|
func (c *AdminClient) generateAdminToken(urlPath string) (string, error) {
|
||||||
// A random jwt id will be used to identify duplicated tokens
|
// A random jwt id will be used to identify duplicated tokens
|
||||||
jwtID, err := randutil.Hex(64) // 256 bits
|
jwtID, err := randutil.Hex(64) // 256 bits
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -82,7 +82,7 @@ func (c *AdminClient) generateAdminToken(path string) (string, error) {
|
||||||
token.WithJWTID(jwtID),
|
token.WithJWTID(jwtID),
|
||||||
token.WithKid(c.x5cJWK.KeyID),
|
token.WithKid(c.x5cJWK.KeyID),
|
||||||
token.WithIssuer(c.x5cIssuer),
|
token.WithIssuer(c.x5cIssuer),
|
||||||
token.WithAudience(path),
|
token.WithAudience(urlPath),
|
||||||
token.WithValidity(now, now.Add(token.DefaultValidity)),
|
token.WithValidity(now, now.Add(token.DefaultValidity)),
|
||||||
token.WithX5CCerts(c.x5cCertStrs),
|
token.WithX5CCerts(c.x5cCertStrs),
|
||||||
}
|
}
|
||||||
|
@ -348,14 +348,15 @@ func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provi
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var u *url.URL
|
var u *url.URL
|
||||||
if len(o.id) > 0 {
|
switch {
|
||||||
|
case len(o.id) > 0:
|
||||||
u = c.endpoint.ResolveReference(&url.URL{
|
u = c.endpoint.ResolveReference(&url.URL{
|
||||||
Path: "/admin/provisioners/id",
|
Path: "/admin/provisioners/id",
|
||||||
RawQuery: o.rawQuery(),
|
RawQuery: o.rawQuery(),
|
||||||
})
|
})
|
||||||
} else if len(o.name) > 0 {
|
case len(o.name) > 0:
|
||||||
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
|
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
|
||||||
} else {
|
default:
|
||||||
return nil, errors.New("must set either name or id in method options")
|
return nil, errors.New("must set either name or id in method options")
|
||||||
}
|
}
|
||||||
tok, err := c.generateAdminToken(u.Path)
|
tok, err := c.generateAdminToken(u.Path)
|
||||||
|
@ -456,14 +457,15 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(o.id) > 0 {
|
switch {
|
||||||
|
case len(o.id) > 0:
|
||||||
u = c.endpoint.ResolveReference(&url.URL{
|
u = c.endpoint.ResolveReference(&url.URL{
|
||||||
Path: path.Join(adminURLPrefix, "provisioners/id"),
|
Path: path.Join(adminURLPrefix, "provisioners/id"),
|
||||||
RawQuery: o.rawQuery(),
|
RawQuery: o.rawQuery(),
|
||||||
})
|
})
|
||||||
} else if len(o.name) > 0 {
|
case len(o.name) > 0:
|
||||||
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
|
u = c.endpoint.ResolveReference(&url.URL{Path: path.Join(adminURLPrefix, "provisioners", o.name)})
|
||||||
} else {
|
default:
|
||||||
return errors.New("must set either name or id in method options")
|
return errors.New("must set either name or id in method options")
|
||||||
}
|
}
|
||||||
tok, err := c.generateAdminToken(u.Path)
|
tok, err := c.generateAdminToken(u.Path)
|
||||||
|
|
|
@ -30,7 +30,7 @@ func Bootstrap(token string) (*Client, error) {
|
||||||
|
|
||||||
// Validate bootstrap token
|
// Validate bootstrap token
|
||||||
switch {
|
switch {
|
||||||
case len(claims.SHA) == 0:
|
case claims.SHA == "":
|
||||||
return nil, errors.New("invalid bootstrap token: sha claim is not present")
|
return nil, errors.New("invalid bootstrap token: sha claim is not present")
|
||||||
case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"):
|
case !strings.HasPrefix(strings.ToLower(claims.Audience[0]), "http"):
|
||||||
return nil, errors.New("invalid bootstrap token: aud claim is not a url")
|
return nil, errors.New("invalid bootstrap token: aud claim is not a url")
|
||||||
|
|
52
ca/ca.go
52
ca/ca.go
|
@ -88,9 +88,9 @@ func WithIssuerPassword(password []byte) Option {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithDatabase sets the given authority database to the CA options.
|
// WithDatabase sets the given authority database to the CA options.
|
||||||
func WithDatabase(db db.AuthDB) Option {
|
func WithDatabase(d db.AuthDB) Option {
|
||||||
return func(o *options) {
|
return func(o *options) {
|
||||||
o.database = db
|
o.database = d
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,17 +113,17 @@ type CA struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates and initializes the CA with the given configuration and options.
|
// New creates and initializes the CA with the given configuration and options.
|
||||||
func New(config *config.Config, opts ...Option) (*CA, error) {
|
func New(cfg *config.Config, opts ...Option) (*CA, error) {
|
||||||
ca := &CA{
|
ca := &CA{
|
||||||
config: config,
|
config: cfg,
|
||||||
opts: new(options),
|
opts: new(options),
|
||||||
}
|
}
|
||||||
ca.opts.apply(opts)
|
ca.opts.apply(opts)
|
||||||
return ca.Init(config)
|
return ca.Init(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes the CA with the given configuration.
|
// Init initializes the CA with the given configuration.
|
||||||
func (ca *CA) Init(config *config.Config) (*CA, error) {
|
func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
||||||
// Set password, it's ok to set nil password, the ca will prompt for them if
|
// Set password, it's ok to set nil password, the ca will prompt for them if
|
||||||
// they are required.
|
// they are required.
|
||||||
opts := []authority.Option{
|
opts := []authority.Option{
|
||||||
|
@ -140,7 +140,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
opts = append(opts, authority.WithDatabase(ca.opts.database))
|
opts = append(opts, authority.WithDatabase(ca.opts.database))
|
||||||
}
|
}
|
||||||
|
|
||||||
auth, err := authority.New(config, opts...)
|
auth, err := authority.New(cfg, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -166,8 +166,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
})
|
})
|
||||||
|
|
||||||
//Add ACME api endpoints in /acme and /1.0/acme
|
//Add ACME api endpoints in /acme and /1.0/acme
|
||||||
dns := config.DNSNames[0]
|
dns := cfg.DNSNames[0]
|
||||||
u, err := url.Parse("https://" + config.Address)
|
u, err := url.Parse("https://" + cfg.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -179,7 +179,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
// ACME Router
|
// ACME Router
|
||||||
prefix := "acme"
|
prefix := "acme"
|
||||||
var acmeDB acme.DB
|
var acmeDB acme.DB
|
||||||
if config.DB == nil {
|
if cfg.DB == nil {
|
||||||
acmeDB = nil
|
acmeDB = nil
|
||||||
} else {
|
} else {
|
||||||
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB))
|
||||||
|
@ -188,7 +188,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
|
acmeHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{
|
||||||
Backdate: *config.AuthorityConfig.Backdate,
|
Backdate: *cfg.AuthorityConfig.Backdate,
|
||||||
DB: acmeDB,
|
DB: acmeDB,
|
||||||
DNS: dns,
|
DNS: dns,
|
||||||
Prefix: prefix,
|
Prefix: prefix,
|
||||||
|
@ -204,7 +204,7 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
})
|
})
|
||||||
|
|
||||||
// Admin API Router
|
// Admin API Router
|
||||||
if config.AuthorityConfig.EnableAdmin {
|
if cfg.AuthorityConfig.EnableAdmin {
|
||||||
adminDB := auth.GetAdminDatabase()
|
adminDB := auth.GetAdminDatabase()
|
||||||
if adminDB != nil {
|
if adminDB != nil {
|
||||||
adminHandler := adminAPI.NewHandler(auth)
|
adminHandler := adminAPI.NewHandler(auth)
|
||||||
|
@ -248,8 +248,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
//dumpRoutes(mux)
|
//dumpRoutes(mux)
|
||||||
|
|
||||||
// Add monitoring if configured
|
// Add monitoring if configured
|
||||||
if len(config.Monitoring) > 0 {
|
if len(cfg.Monitoring) > 0 {
|
||||||
m, err := monitoring.New(config.Monitoring)
|
m, err := monitoring.New(cfg.Monitoring)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -258,8 +258,8 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add logger if configured
|
// Add logger if configured
|
||||||
if len(config.Logger) > 0 {
|
if len(cfg.Logger) > 0 {
|
||||||
logger, err := logging.New("ca", config.Logger)
|
logger, err := logging.New("ca", cfg.Logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -267,16 +267,16 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
insecureHandler = logger.Middleware(insecureHandler)
|
insecureHandler = logger.Middleware(insecureHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
ca.srv = server.New(config.Address, handler, tlsConfig)
|
ca.srv = server.New(cfg.Address, handler, tlsConfig)
|
||||||
|
|
||||||
// only start the insecure server if the insecure address is configured
|
// only start the insecure server if the insecure address is configured
|
||||||
// and, currently, also only when it should serve SCEP endpoints.
|
// and, currently, also only when it should serve SCEP endpoints.
|
||||||
if ca.shouldServeSCEPEndpoints() && config.InsecureAddress != "" {
|
if ca.shouldServeSCEPEndpoints() && cfg.InsecureAddress != "" {
|
||||||
// TODO: instead opt for having a single server.Server but two
|
// TODO: instead opt for having a single server.Server but two
|
||||||
// http.Servers handling the HTTP and HTTPS handler? The latter
|
// http.Servers handling the HTTP and HTTPS handler? The latter
|
||||||
// will probably introduce more complexity in terms of graceful
|
// will probably introduce more complexity in terms of graceful
|
||||||
// reload.
|
// reload.
|
||||||
ca.insecureSrv = server.New(config.InsecureAddress, insecureHandler, nil)
|
ca.insecureSrv = server.New(cfg.InsecureAddress, insecureHandler, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ca, nil
|
return ca, nil
|
||||||
|
@ -285,24 +285,24 @@ func (ca *CA) Init(config *config.Config) (*CA, error) {
|
||||||
// Run starts the CA calling to the server ListenAndServe method.
|
// Run starts the CA calling to the server ListenAndServe method.
|
||||||
func (ca *CA) Run() error {
|
func (ca *CA) Run() error {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
errors := make(chan error, 1)
|
errs := make(chan error, 1)
|
||||||
|
|
||||||
if ca.insecureSrv != nil {
|
if ca.insecureSrv != nil {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
errors <- ca.insecureSrv.ListenAndServe()
|
errs <- ca.insecureSrv.ListenAndServe()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
errors <- ca.srv.ListenAndServe()
|
errs <- ca.srv.ListenAndServe()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// wait till error occurs; ensures the servers keep listening
|
// wait till error occurs; ensures the servers keep listening
|
||||||
err := <-errors
|
err := <-errs
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
|
@ -331,7 +331,7 @@ func (ca *CA) Stop() error {
|
||||||
// Reload reloads the configuration of the CA and calls to the server Reload
|
// Reload reloads the configuration of the CA and calls to the server Reload
|
||||||
// method.
|
// method.
|
||||||
func (ca *CA) Reload() error {
|
func (ca *CA) Reload() error {
|
||||||
config, err := config.LoadConfiguration(ca.opts.configFile)
|
cfg, err := config.LoadConfiguration(ca.opts.configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "error reloading ca configuration")
|
return errors.Wrap(err, "error reloading ca configuration")
|
||||||
}
|
}
|
||||||
|
@ -343,12 +343,12 @@ func (ca *CA) Reload() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do not allow reload if the database configuration has changed.
|
// Do not allow reload if the database configuration has changed.
|
||||||
if !reflect.DeepEqual(ca.config.DB, config.DB) {
|
if !reflect.DeepEqual(ca.config.DB, cfg.DB) {
|
||||||
logContinue("Reload failed because the database configuration has changed.")
|
logContinue("Reload failed because the database configuration has changed.")
|
||||||
return errors.New("error reloading ca: database configuration cannot change")
|
return errors.New("error reloading ca: database configuration cannot change")
|
||||||
}
|
}
|
||||||
|
|
||||||
newCA, err := New(config,
|
newCA, err := New(cfg,
|
||||||
WithPassword(ca.opts.password),
|
WithPassword(ca.opts.password),
|
||||||
WithSSHHostPassword(ca.opts.sshHostPassword),
|
WithSSHHostPassword(ca.opts.sshHostPassword),
|
||||||
WithSSHUserPassword(ca.opts.sshUserPassword),
|
WithSSHUserPassword(ca.opts.sshUserPassword),
|
||||||
|
|
|
@ -322,7 +322,7 @@ ZEp7knvU2psWRw==
|
||||||
assert.Equals(t, intermediate, realIntermediate)
|
assert.Equals(t, intermediate, realIntermediate)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
@ -375,7 +375,7 @@ func TestCAProvisioners(t *testing.T) {
|
||||||
assert.Equals(t, a, b)
|
assert.Equals(t, a, b)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
@ -436,7 +436,7 @@ func TestCAProvisionerEncryptedKey(t *testing.T) {
|
||||||
assert.Equals(t, ek.Key, tc.expectedKey)
|
assert.Equals(t, ek.Key, tc.expectedKey)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
@ -497,7 +497,7 @@ func TestCARoot(t *testing.T) {
|
||||||
assert.Equals(t, root.RootPEM.Certificate, rootCrt)
|
assert.Equals(t, root.RootPEM.Certificate, rootCrt)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
@ -665,7 +665,7 @@ func TestCARenew(t *testing.T) {
|
||||||
assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
|
assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions)
|
||||||
} else {
|
} else {
|
||||||
err := readError(body)
|
err := readError(body)
|
||||||
if len(tc.errMsg) == 0 {
|
if tc.errMsg == "" {
|
||||||
assert.FatalError(t, errors.New("must validate response error"))
|
assert.FatalError(t, errors.New("must validate response error"))
|
||||||
}
|
}
|
||||||
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
assert.HasPrefix(t, err.Error(), tc.errMsg)
|
||||||
|
|
32
ca/client.go
32
ca/client.go
|
@ -74,17 +74,17 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) {
|
||||||
c.Client.Transport = tr
|
c.Client.Transport = tr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *uaClient) Get(url string) (*http.Response, error) {
|
func (c *uaClient) Get(u string) (*http.Response, error) {
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
req, err := http.NewRequest("GET", u, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrapf(err, "new request GET %s failed", url)
|
return nil, errors.Wrapf(err, "new request GET %s failed", u)
|
||||||
}
|
}
|
||||||
req.Header.Set("User-Agent", UserAgent)
|
req.Header.Set("User-Agent", UserAgent)
|
||||||
return c.Client.Do(req)
|
return c.Client.Do(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) {
|
func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) {
|
||||||
req, err := http.NewRequest("POST", url, body)
|
req, err := http.NewRequest("POST", u, body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -305,7 +305,7 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
|
||||||
err error
|
err error
|
||||||
opts []jose.Option
|
opts []jose.Option
|
||||||
)
|
)
|
||||||
if len(passwordFile) != 0 {
|
if passwordFile != "" {
|
||||||
opts = append(opts, jose.WithPasswordFile(passwordFile))
|
opts = append(opts, jose.WithPasswordFile(passwordFile))
|
||||||
}
|
}
|
||||||
blk, err := pemutil.Serialize(key)
|
blk, err := pemutil.Serialize(key)
|
||||||
|
@ -326,14 +326,14 @@ func WithAdminX5C(certs []*x509.Certificate, key interface{}, passwordFile strin
|
||||||
|
|
||||||
for _, e := range o.x5cCert.Extensions {
|
for _, e := range o.x5cCert.Extensions {
|
||||||
if e.Id.Equal(stepOIDProvisioner) {
|
if e.Id.Equal(stepOIDProvisioner) {
|
||||||
var provisioner stepProvisionerASN1
|
var prov stepProvisionerASN1
|
||||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
if _, err := asn1.Unmarshal(e.Value, &prov); err != nil {
|
||||||
return errors.Wrap(err, "error unmarshaling provisioner OID from certificate")
|
return errors.Wrap(err, "error unmarshaling provisioner OID from certificate")
|
||||||
}
|
}
|
||||||
o.x5cIssuer = string(provisioner.Name)
|
o.x5cIssuer = string(prov.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(o.x5cIssuer) == 0 {
|
if o.x5cIssuer == "" {
|
||||||
return errors.New("provisioner extension not found in certificate")
|
return errors.New("provisioner extension not found in certificate")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -631,7 +631,7 @@ retry:
|
||||||
// do not match.
|
// do not match.
|
||||||
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
|
||||||
var retried bool
|
var retried bool
|
||||||
sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1))
|
sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", ""))
|
||||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
|
||||||
retry:
|
retry:
|
||||||
resp, err := newInsecureClient().Get(u.String())
|
resp, err := newInsecureClient().Get(u.String())
|
||||||
|
@ -651,7 +651,7 @@ retry:
|
||||||
}
|
}
|
||||||
// verify the sha256
|
// verify the sha256
|
||||||
sum := sha256.Sum256(root.RootPEM.Raw)
|
sum := sha256.Sum256(root.RootPEM.Raw)
|
||||||
if sha256Sum != strings.ToLower(hex.EncodeToString(sum[:])) {
|
if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) {
|
||||||
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
|
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match")
|
||||||
}
|
}
|
||||||
return &root, nil
|
return &root, nil
|
||||||
|
@ -1066,16 +1066,16 @@ retry:
|
||||||
}
|
}
|
||||||
return nil, readError(resp.Body)
|
return nil, readError(resp.Body)
|
||||||
}
|
}
|
||||||
var config api.SSHConfigResponse
|
var cfg api.SSHConfigResponse
|
||||||
if err := readJSON(resp.Body, &config); err != nil {
|
if err := readJSON(resp.Body, &cfg); err != nil {
|
||||||
return nil, errors.Wrapf(err, "error reading %s", u)
|
return nil, errors.Wrapf(err, "error reading %s", u)
|
||||||
}
|
}
|
||||||
return &config, nil
|
return &cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the
|
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the
|
||||||
// given principal.
|
// given principal.
|
||||||
func (c *Client) SSHCheckHost(principal string, token string) (*api.SSHCheckPrincipalResponse, error) {
|
func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) {
|
||||||
var retried bool
|
var retried bool
|
||||||
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
|
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
|
||||||
Type: provisioner.SSHHostCert,
|
Type: provisioner.SSHHostCert,
|
||||||
|
|
|
@ -135,7 +135,7 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
|
||||||
return csr
|
return csr
|
||||||
}
|
}
|
||||||
|
|
||||||
func equalJSON(t *testing.T, a interface{}, b interface{}) bool {
|
func equalJSON(t *testing.T, a, b interface{}) bool {
|
||||||
if reflect.DeepEqual(a, b) {
|
if reflect.DeepEqual(a, b) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -187,11 +187,12 @@ func TestLoadClient(t *testing.T) {
|
||||||
} else {
|
} else {
|
||||||
gotTransport := got.Client.Transport.(*http.Transport)
|
gotTransport := got.Client.Transport.(*http.Transport)
|
||||||
wantTransport := tt.want.Client.Transport.(*http.Transport)
|
wantTransport := tt.want.Client.Transport.(*http.Transport)
|
||||||
if gotTransport.TLSClientConfig.GetClientCertificate == nil {
|
switch {
|
||||||
|
case gotTransport.TLSClientConfig.GetClientCertificate == nil:
|
||||||
t.Error("LoadClient() transport does not define GetClientCertificate")
|
t.Error("LoadClient() transport does not define GetClientCertificate")
|
||||||
} else if !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()) {
|
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()):
|
||||||
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
|
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
|
||||||
} else {
|
default:
|
||||||
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
|
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("LoadClient() GetClientCertificate error = %v", err)
|
t.Errorf("LoadClient() GetClientCertificate error = %v", err)
|
||||||
|
|
|
@ -105,7 +105,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
|
|
||||||
tr := getDefaultTransport(tlsConfig)
|
tr := getDefaultTransport(tlsConfig)
|
||||||
// Use mutable tls.Config on renew
|
// Use mutable tls.Config on renew
|
||||||
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic
|
||||||
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
// Update renew function with transport
|
// Update renew function with transport
|
||||||
tr := getDefaultTransport(tlsConfig)
|
tr := getDefaultTransport(tlsConfig)
|
||||||
// Use mutable tls.Config on renew
|
// Use mutable tls.Config on renew
|
||||||
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck
|
tr.DialTLS = c.buildDialTLS(tlsCtx) // nolint:staticcheck,gocritic
|
||||||
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
|
@ -195,7 +195,7 @@ func (c *Client) buildDialTLS(ctx *TLSOptionCtx) func(network, addr string) (net
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
|
// buildDialTLSContext returns an implementation of DialTLSContext callback in http.Transport.
|
||||||
// nolint:unused
|
// nolint:unused,gocritic
|
||||||
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
func (c *Client) buildDialTLSContext(tlsCtx *TLSOptionCtx) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
d := getDefaultDialer()
|
d := getDefaultDialer()
|
||||||
|
@ -253,6 +253,8 @@ func TLSCertificate(sign *api.SignResponse, pk crypto.PrivateKey) (*tls.Certific
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint:gocritic
|
||||||
|
// using a new variable for clarity
|
||||||
chain := append(certPEM, caPEM...)
|
chain := append(certPEM, caPEM...)
|
||||||
cert, err := tls.X509KeyPair(chain, keyPEM)
|
cert, err := tls.X509KeyPair(chain, keyPEM)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -29,9 +29,7 @@ func init() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var now = func() time.Time {
|
var now = time.Now
|
||||||
return time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
// The actual regular expression that matches a certificate authority is:
|
// The actual regular expression that matches a certificate authority is:
|
||||||
// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$
|
// ^projects/[a-z][a-z0-9-]{4,28}[a-z0-9]/locations/[a-z0-9-]+/caPools/[a-zA-Z0-9-_]+/certificateAuthorities/[a-zA-Z0-9-_]+$
|
||||||
|
|
|
@ -12,7 +12,6 @@ import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -103,7 +102,7 @@ MHcCAQEEIN51Rgg6YcQVLeCRzumdw4pjM3VWqFIdCbnsV3Up1e/goAoGCCqGSM49
|
||||||
AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk
|
AwEHoUQDQgAEjJIcDhvvxi7gu4aFkiW/8+E3BfPhmhXU5RlDQusre+MHXc7XYMtk
|
||||||
Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg==
|
Lm6PXPeTF1DNdS21Ju1G/j1yUykGJOmxkg==
|
||||||
-----END EC PRIVATE KEY-----`
|
-----END EC PRIVATE KEY-----`
|
||||||
// nolint:unused,deadcode
|
// nolint:unused,deadcode,gocritic
|
||||||
testIntermediateKey = `-----BEGIN EC PRIVATE KEY-----
|
testIntermediateKey = `-----BEGIN EC PRIVATE KEY-----
|
||||||
MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49
|
MHcCAQEEIMMX/XkXGnRDD4fYu7Z4rHACdJn/iyOy2UTwsv+oZ0C+oAoGCCqGSM49
|
||||||
AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6
|
AwEHoUQDQgAE8u6rGAFj5CZpdzzMogLwUyCMnp0X9wtv4OKDRcpzkYf9PU5GuGA6
|
||||||
|
@ -190,7 +189,7 @@ func (b *badSigner) Public() crypto.PublicKey {
|
||||||
return b.pub
|
return b.pub
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
func (b *badSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("💥")
|
return nil, fmt.Errorf("💥")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -730,7 +729,7 @@ func TestCloudCAS_RevokeCertificate(t *testing.T) {
|
||||||
func Test_createCertificateID(t *testing.T) {
|
func Test_createCertificateID(t *testing.T) {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
setTeeReader(t, buf)
|
setTeeReader(t, buf)
|
||||||
uuid, err := uuid.NewRandomFromReader(rand.Reader)
|
id, err := uuid.NewRandomFromReader(rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -741,7 +740,7 @@ func Test_createCertificateID(t *testing.T) {
|
||||||
want string
|
want string
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{"ok", uuid.String(), false},
|
{"ok", id.String(), false},
|
||||||
{"fail", "", true},
|
{"fail", "", true},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -858,7 +857,7 @@ func TestCloudCAS_CreateCertificateAuthority(t *testing.T) {
|
||||||
return lis.Dial()
|
return lis.Dial()
|
||||||
}))
|
}))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn))
|
client, err := lroauto.NewOperationsClient(context.Background(), option.WithGRPCConn(conn))
|
||||||
|
|
|
@ -19,9 +19,7 @@ func init() {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var now = func() time.Time {
|
var now = time.Now
|
||||||
return time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SoftCAS implements a Certificate Authority Service using Golang or KMS
|
// SoftCAS implements a Certificate Authority Service using Golang or KMS
|
||||||
// crypto. This is the default CAS used in step-ca.
|
// crypto. This is the default CAS used in step-ca.
|
||||||
|
|
|
@ -133,7 +133,7 @@ func (b *badSigner) Public() crypto.PublicKey {
|
||||||
return testSigner.Public()
|
return testSigner.Public()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *badSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
|
func (b *badSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) {
|
||||||
return nil, fmt.Errorf("💥")
|
return nil, fmt.Errorf("💥")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -90,9 +90,9 @@ func (s *StepCAS) RenewCertificate(req *apiv1.RenewCertificateRequest) (*apiv1.R
|
||||||
return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"}
|
return nil, apiv1.ErrNotImplemented{Message: "stepCAS does not support mTLS renewals"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RevokeCertificate revokes a certificate.
|
||||||
func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
|
func (s *StepCAS) RevokeCertificate(req *apiv1.RevokeCertificateRequest) (*apiv1.RevokeCertificateResponse, error) {
|
||||||
switch {
|
if req.SerialNumber == "" && req.Certificate == nil {
|
||||||
case req.SerialNumber == "" && req.Certificate == nil:
|
|
||||||
return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required")
|
return nil, errors.New("revokeCertificateRequest `serialNumber` or `certificate` are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,9 +19,7 @@ const defaultValidity = 5 * time.Minute
|
||||||
|
|
||||||
// timeNow returns the current time.
|
// timeNow returns the current time.
|
||||||
// This method is used for unit testing purposes.
|
// This method is used for unit testing purposes.
|
||||||
var timeNow = func() time.Time {
|
var timeNow = time.Now
|
||||||
return time.Now()
|
|
||||||
}
|
|
||||||
|
|
||||||
type x5cIssuer struct {
|
type x5cIssuer struct {
|
||||||
caURL *url.URL
|
caURL *url.URL
|
||||||
|
|
|
@ -22,7 +22,7 @@ func (b noneSigner) Public() crypto.PublicKey {
|
||||||
return []byte(b)
|
return []byte(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b noneSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
|
func (b noneSigner) Sign(rnd io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
|
||||||
return digest, nil
|
return digest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -24,10 +24,10 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var credentialsFile, region string
|
var credentialsFile, region string
|
||||||
var ssh bool
|
var enableSSH bool
|
||||||
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.")
|
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the AWS KMS credentials.")
|
||||||
flag.StringVar(®ion, "region", "", "AWS KMS region name.")
|
flag.StringVar(®ion, "region", "", "AWS KMS region name.")
|
||||||
flag.BoolVar(&ssh, "ssh", false, "Create SSH keys.")
|
flag.BoolVar(&enableSSH, "ssh", false, "Create SSH keys.")
|
||||||
flag.Usage = usage
|
flag.Usage = usage
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ func main() {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ssh {
|
if enableSSH {
|
||||||
ui.Println()
|
ui.Println()
|
||||||
if err := createSSH(c); err != nil {
|
if err := createSSH(c); err != nil {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
|
@ -120,7 +120,7 @@ func createX509(c *awskms.KMS) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
@ -163,7 +163,7 @@ func createX509(c *awskms.KMS) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
@ -193,7 +193,7 @@ func createSSH(c *awskms.KMS) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
if err := fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,7 +214,7 @@ func createSSH(c *awskms.KMS) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
if err := fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -191,8 +191,8 @@ var placeholderString = regexp.MustCompile(`<.*?>`)
|
||||||
|
|
||||||
func stringifyFlag(f cli.Flag) string {
|
func stringifyFlag(f cli.Flag) string {
|
||||||
fv := flagValue(f)
|
fv := flagValue(f)
|
||||||
usage := fv.FieldByName("Usage").String()
|
usg := fv.FieldByName("Usage").String()
|
||||||
placeholder := placeholderString.FindString(usage)
|
placeholder := placeholderString.FindString(usg)
|
||||||
if placeholder == "" {
|
if placeholder == "" {
|
||||||
switch f.(type) {
|
switch f.(type) {
|
||||||
case cli.BoolFlag, cli.BoolTFlag:
|
case cli.BoolFlag, cli.BoolTFlag:
|
||||||
|
@ -200,5 +200,5 @@ func stringifyFlag(f cli.Flag) string {
|
||||||
placeholder = "<value>"
|
placeholder = "<value>"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usage
|
return cli.FlagNamePrefixer(fv.FieldByName("Name").String(), placeholder) + "\t" + usg
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,13 +27,13 @@ func main() {
|
||||||
var credentialsFile string
|
var credentialsFile string
|
||||||
var project, location, ring string
|
var project, location, ring string
|
||||||
var protectionLevelName string
|
var protectionLevelName string
|
||||||
var ssh bool
|
var enableSSH bool
|
||||||
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the Google's Cloud KMS credentials.")
|
flag.StringVar(&credentialsFile, "credentials-file", "", "Path to the `file` containing the Google's Cloud KMS credentials.")
|
||||||
flag.StringVar(&project, "project", "", "Google Cloud Project ID.")
|
flag.StringVar(&project, "project", "", "Google Cloud Project ID.")
|
||||||
flag.StringVar(&location, "location", "global", "Cloud KMS location name.")
|
flag.StringVar(&location, "location", "global", "Cloud KMS location name.")
|
||||||
flag.StringVar(&ring, "ring", "pki", "Cloud KMS ring name.")
|
flag.StringVar(&ring, "ring", "pki", "Cloud KMS ring name.")
|
||||||
flag.StringVar(&protectionLevelName, "protection-level", "SOFTWARE", "Protection level to use, SOFTWARE or HSM.")
|
flag.StringVar(&protectionLevelName, "protection-level", "SOFTWARE", "Protection level to use, SOFTWARE or HSM.")
|
||||||
flag.BoolVar(&ssh, "ssh", false, "Create SSH keys.")
|
flag.BoolVar(&enableSSH, "ssh", false, "Create SSH keys.")
|
||||||
flag.Usage = usage
|
flag.Usage = usage
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ func main() {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ssh {
|
if enableSSH {
|
||||||
ui.Println()
|
ui.Println()
|
||||||
if err := createSSH(c, project, location, ring, protectionLevel); err != nil {
|
if err := createSSH(c, project, location, ring, protectionLevel); err != nil {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
|
@ -153,7 +153,7 @@ func createPKI(c *cloudkms.CloudKMS, project, location, keyRing string, protecti
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
@ -197,7 +197,7 @@ func createPKI(c *cloudkms.CloudKMS, project, location, keyRing string, protecti
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
@ -230,7 +230,7 @@ func createSSH(c *cloudkms.CloudKMS, project, location, keyRing string, protecti
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
if err := fileutil.WriteFile("ssh_user_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -252,7 +252,7 @@ func createSSH(c *cloudkms.CloudKMS, project, location, keyRing string, protecti
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
if err := fileutil.WriteFile("ssh_host_ca_key.pub", ssh.MarshalAuthorizedKey(key), 0600); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -329,7 +329,7 @@ func createPKI(k kms.KeyManager, c Config) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts {
|
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts {
|
||||||
if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{
|
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
|
||||||
Name: c.RootObject,
|
Name: c.RootObject,
|
||||||
Certificate: root,
|
Certificate: root,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
@ -337,7 +337,7 @@ func createPKI(k kms.KeyManager, c Config) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
@ -406,7 +406,7 @@ func createPKI(k kms.KeyManager, c Config) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts {
|
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts {
|
||||||
if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{
|
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
|
||||||
Name: c.CrtObject,
|
Name: c.CrtObject,
|
||||||
Certificate: intermediate,
|
Certificate: intermediate,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
@ -414,7 +414,7 @@ func createPKI(k kms.KeyManager, c Config) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
|
|
@ -228,7 +228,7 @@ func createPKI(k kms.KeyManager, c Config) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm, ok := k.(kms.CertificateManager); ok {
|
if cm, ok := k.(kms.CertificateManager); ok {
|
||||||
if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{
|
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
|
||||||
Name: c.RootSlot,
|
Name: c.RootSlot,
|
||||||
Certificate: root,
|
Certificate: root,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
@ -236,7 +236,7 @@ func createPKI(k kms.KeyManager, c Config) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile("root_ca.crt", pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
@ -305,7 +305,7 @@ func createPKI(k kms.KeyManager, c Config) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cm, ok := k.(kms.CertificateManager); ok {
|
if cm, ok := k.(kms.CertificateManager); ok {
|
||||||
if err = cm.StoreCertificate(&apiv1.StoreCertificateRequest{
|
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
|
||||||
Name: c.CrtSlot,
|
Name: c.CrtSlot,
|
||||||
Certificate: intermediate,
|
Certificate: intermediate,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
|
@ -313,7 +313,7 @@ func createPKI(k kms.KeyManager, c Config) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
|
if err := fileutil.WriteFile("intermediate_ca.crt", pem.EncodeToMemory(&pem.Block{
|
||||||
Type: "CERTIFICATE",
|
Type: "CERTIFICATE",
|
||||||
Bytes: b,
|
Bytes: b,
|
||||||
}), 0600); err != nil {
|
}), 0600); err != nil {
|
||||||
|
|
|
@ -79,13 +79,13 @@ func appAction(ctx *cli.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
configFile := ctx.Args().Get(0)
|
configFile := ctx.Args().Get(0)
|
||||||
config, err := config.LoadConfiguration(configFile)
|
cfg, err := config.LoadConfiguration(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.AuthorityConfig != nil {
|
if cfg.AuthorityConfig != nil {
|
||||||
if token == "" && strings.EqualFold(config.AuthorityConfig.DeploymentType, pki.LinkedDeployment.String()) {
|
if token == "" && strings.EqualFold(cfg.AuthorityConfig.DeploymentType, pki.LinkedDeployment.String()) {
|
||||||
return errors.New(`'step-ca' requires the '--token' flag for linked deploy type.
|
return errors.New(`'step-ca' requires the '--token' flag for linked deploy type.
|
||||||
|
|
||||||
To get a linked authority token:
|
To get a linked authority token:
|
||||||
|
@ -136,7 +136,7 @@ To get a linked authority token:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
srv, err := ca.New(config,
|
srv, err := ca.New(cfg,
|
||||||
ca.WithConfigFile(configFile),
|
ca.WithConfigFile(configFile),
|
||||||
ca.WithPassword(password),
|
ca.WithPassword(password),
|
||||||
ca.WithSSHHostPassword(sshHostPassword),
|
ca.WithSSHHostPassword(sshHostPassword),
|
||||||
|
|
|
@ -63,11 +63,11 @@ func exportAction(ctx *cli.Context) error {
|
||||||
passwordFile := ctx.String("password-file")
|
passwordFile := ctx.String("password-file")
|
||||||
issuerPasswordFile := ctx.String("issuer-password-file")
|
issuerPasswordFile := ctx.String("issuer-password-file")
|
||||||
|
|
||||||
config, err := config.LoadConfiguration(configFile)
|
cfg, err := config.LoadConfiguration(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := config.Validate(); err != nil {
|
if err := cfg.Validate(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,19 +76,19 @@ func exportAction(ctx *cli.Context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "error reading %s", passwordFile)
|
return errors.Wrapf(err, "error reading %s", passwordFile)
|
||||||
}
|
}
|
||||||
config.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
|
cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
|
||||||
}
|
}
|
||||||
if issuerPasswordFile != "" {
|
if issuerPasswordFile != "" {
|
||||||
b, err := ioutil.ReadFile(issuerPasswordFile)
|
b, err := ioutil.ReadFile(issuerPasswordFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "error reading %s", issuerPasswordFile)
|
return errors.Wrapf(err, "error reading %s", issuerPasswordFile)
|
||||||
}
|
}
|
||||||
if config.AuthorityConfig.CertificateIssuer != nil {
|
if cfg.AuthorityConfig.CertificateIssuer != nil {
|
||||||
config.AuthorityConfig.CertificateIssuer.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
|
cfg.AuthorityConfig.CertificateIssuer.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auth, err := authority.New(config)
|
auth, err := authority.New(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -103,8 +103,8 @@ func onboardAction(ctx *cli.Context) error {
|
||||||
return errors.Wrap(msg, "error receiving onboarding guide")
|
return errors.Wrap(msg, "error receiving onboarding guide")
|
||||||
}
|
}
|
||||||
|
|
||||||
var config onboardingConfiguration
|
var cfg onboardingConfiguration
|
||||||
if err := readJSON(res.Body, &config); err != nil {
|
if err := readJSON(res.Body, &cfg); err != nil {
|
||||||
return errors.Wrap(err, "error unmarshaling response")
|
return errors.Wrap(err, "error unmarshaling response")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,16 +112,16 @@ func onboardAction(ctx *cli.Context) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
config.password = []byte(password)
|
cfg.password = []byte(password)
|
||||||
|
|
||||||
ui.Println("Initializing step-ca with the following configuration:")
|
ui.Println("Initializing step-ca with the following configuration:")
|
||||||
ui.PrintSelected("Name", config.Name)
|
ui.PrintSelected("Name", cfg.Name)
|
||||||
ui.PrintSelected("DNS", config.DNS)
|
ui.PrintSelected("DNS", cfg.DNS)
|
||||||
ui.PrintSelected("Address", config.Address)
|
ui.PrintSelected("Address", cfg.Address)
|
||||||
ui.PrintSelected("Password", password)
|
ui.PrintSelected("Password", password)
|
||||||
ui.Println()
|
ui.Println()
|
||||||
|
|
||||||
caConfig, fp, err := onboardPKI(config)
|
caConfig, fp, err := onboardPKI(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -149,23 +149,23 @@ func onboardAction(ctx *cli.Context) error {
|
||||||
ui.Println("Initialized!")
|
ui.Println("Initialized!")
|
||||||
ui.Println("Step CA is starting. Please return to the onboarding guide in your browser to continue.")
|
ui.Println("Step CA is starting. Please return to the onboarding guide in your browser to continue.")
|
||||||
|
|
||||||
srv, err := ca.New(caConfig, ca.WithPassword(config.password))
|
srv, err := ca.New(caConfig, ca.WithPassword(cfg.password))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go ca.StopReloaderHandler(srv)
|
go ca.StopReloaderHandler(srv)
|
||||||
if err = srv.Run(); err != nil && err != http.ErrServerClosed {
|
if err := srv.Run(); err != nil && err != http.ErrServerClosed {
|
||||||
fatal(err)
|
fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func onboardPKI(config onboardingConfiguration) (*config.Config, string, error) {
|
func onboardPKI(cfg onboardingConfiguration) (*config.Config, string, error) {
|
||||||
var opts = []pki.Option{
|
var opts = []pki.Option{
|
||||||
pki.WithAddress(config.Address),
|
pki.WithAddress(cfg.Address),
|
||||||
pki.WithDNSNames([]string{config.DNS}),
|
pki.WithDNSNames([]string{cfg.DNS}),
|
||||||
pki.WithProvisioner("admin"),
|
pki.WithProvisioner("admin"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,25 +179,25 @@ func onboardPKI(config onboardingConfiguration) (*config.Config, string, error)
|
||||||
|
|
||||||
// Generate pki
|
// Generate pki
|
||||||
ui.Println("Generating root certificate...")
|
ui.Println("Generating root certificate...")
|
||||||
root, err := p.GenerateRootCertificate(config.Name, config.Name, config.Name, config.password)
|
root, err := p.GenerateRootCertificate(cfg.Name, cfg.Name, cfg.Name, cfg.password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
ui.Println("Generating intermediate certificate...")
|
ui.Println("Generating intermediate certificate...")
|
||||||
err = p.GenerateIntermediateCertificate(config.Name, config.Name, config.Name, root, config.password)
|
err = p.GenerateIntermediateCertificate(cfg.Name, cfg.Name, cfg.Name, root, cfg.password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write files to disk
|
// Write files to disk
|
||||||
if err = p.WriteFiles(); err != nil {
|
if err := p.WriteFiles(); err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate provisioner
|
// Generate provisioner
|
||||||
ui.Println("Generating admin provisioner...")
|
ui.Println("Generating admin provisioner...")
|
||||||
if err = p.GenerateKeyPairs(config.password); err != nil {
|
if err := p.GenerateKeyPairs(cfg.password); err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,7 +211,7 @@ func onboardPKI(config onboardingConfiguration) (*config.Config, string, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", errors.Wrapf(err, "error marshaling %s", p.GetCAConfigPath())
|
return nil, "", errors.Wrapf(err, "error marshaling %s", p.GetCAConfigPath())
|
||||||
}
|
}
|
||||||
if err = fileutil.WriteFile(p.GetCAConfigPath(), b, 0666); err != nil {
|
if err := fileutil.WriteFile(p.GetCAConfigPath(), b, 0666); err != nil {
|
||||||
return nil, "", errs.FileError(err, p.GetCAConfigPath())
|
return nil, "", errs.FileError(err, p.GetCAConfigPath())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,15 +144,15 @@ func TestUseToken(t *testing.T) {
|
||||||
}
|
}
|
||||||
for name, tc := range tests {
|
for name, tc := range tests {
|
||||||
t.Run(name, func(t *testing.T) {
|
t.Run(name, func(t *testing.T) {
|
||||||
ok, err := tc.db.UseToken(tc.id, tc.tok)
|
switch ok, err := tc.db.UseToken(tc.id, tc.tok); {
|
||||||
if err != nil {
|
case err != nil:
|
||||||
if assert.NotNil(t, tc.want.err) {
|
if assert.NotNil(t, tc.want.err) {
|
||||||
assert.HasPrefix(t, err.Error(), tc.want.err.Error())
|
assert.HasPrefix(t, err.Error(), tc.want.err.Error())
|
||||||
}
|
}
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
} else if ok {
|
case ok:
|
||||||
assert.True(t, tc.want.ok)
|
assert.True(t, tc.want.ok)
|
||||||
} else {
|
default:
|
||||||
assert.False(t, tc.want.ok)
|
assert.False(t, tc.want.ok)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -166,10 +166,10 @@ func (s *privateKey) Delete() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *privateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
func (s *privateKey) Decrypt(rnd io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
||||||
k, ok := s.Signer.(*rsa.PrivateKey)
|
k, ok := s.Signer.(*rsa.PrivateKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("key is not an rsa key")
|
return nil, errors.New("key is not an rsa key")
|
||||||
}
|
}
|
||||||
return k.Decrypt(rand, msg, opts)
|
return k.Decrypt(rnd, msg, opts)
|
||||||
}
|
}
|
||||||
|
|
|
@ -145,8 +145,7 @@ func (k *PKCS11) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
|
||||||
// CreateSigner creates a signer using the key present in the PKCS#11 MODULE signature
|
// CreateSigner creates a signer using the key present in the PKCS#11 MODULE signature
|
||||||
// slot.
|
// slot.
|
||||||
func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) {
|
||||||
switch {
|
if req.SigningKey == "" {
|
||||||
case req.SigningKey == "":
|
|
||||||
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
|
return nil, errors.New("createSignerRequest 'signingKey' cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -204,8 +203,8 @@ func (k *PKCS11) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteKey is a utility function to delete a key given an uri.
|
// DeleteKey is a utility function to delete a key given an uri.
|
||||||
func (k *PKCS11) DeleteKey(uri string) error {
|
func (k *PKCS11) DeleteKey(u string) error {
|
||||||
id, object, err := parseObject(uri)
|
id, object, err := parseObject(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "deleteKey failed")
|
return errors.Wrap(err, "deleteKey failed")
|
||||||
}
|
}
|
||||||
|
@ -223,8 +222,8 @@ func (k *PKCS11) DeleteKey(uri string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteCertificate is a utility function to delete a certificate given an uri.
|
// DeleteCertificate is a utility function to delete a certificate given an uri.
|
||||||
func (k *PKCS11) DeleteCertificate(uri string) error {
|
func (k *PKCS11) DeleteCertificate(u string) error {
|
||||||
id, object, err := parseObject(uri)
|
id, object, err := parseObject(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "deleteCertificate failed")
|
return errors.Wrap(err, "deleteCertificate failed")
|
||||||
}
|
}
|
||||||
|
|
|
@ -378,6 +378,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
|
||||||
t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// nolint:gocritic
|
||||||
switch s := got.(type) {
|
switch s := got.(type) {
|
||||||
case *WrappedSSHSigner:
|
case *WrappedSSHSigner:
|
||||||
gotPkS := s.Sshsigner.PublicKey().(*agent.Key).String() + "\n"
|
gotPkS := s.Sshsigner.PublicKey().(*agent.Key).String() + "\n"
|
||||||
|
@ -562,6 +563,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
|
||||||
t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// nolint:gocritic
|
||||||
switch tt.want.(type) {
|
switch tt.want.(type) {
|
||||||
case ssh.PublicKey:
|
case ssh.PublicKey:
|
||||||
// If we want a ssh.PublicKey, protote got to a
|
// If we want a ssh.PublicKey, protote got to a
|
||||||
|
|
55
pki/pki.go
55
pki/pki.go
|
@ -128,7 +128,7 @@ func GetTemplatesPath() string {
|
||||||
|
|
||||||
// GetProvisioners returns the map of provisioners on the given CA.
|
// GetProvisioners returns the map of provisioners on the given CA.
|
||||||
func GetProvisioners(caURL, rootFile string) (provisioner.List, error) {
|
func GetProvisioners(caURL, rootFile string) (provisioner.List, error) {
|
||||||
if len(rootFile) == 0 {
|
if rootFile == "" {
|
||||||
rootFile = GetRootCAPath()
|
rootFile = GetRootCAPath()
|
||||||
}
|
}
|
||||||
client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile))
|
client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile))
|
||||||
|
@ -153,7 +153,7 @@ func GetProvisioners(caURL, rootFile string) (provisioner.List, error) {
|
||||||
// GetProvisionerKey returns the encrypted provisioner key with the for the
|
// GetProvisionerKey returns the encrypted provisioner key with the for the
|
||||||
// given kid.
|
// given kid.
|
||||||
func GetProvisionerKey(caURL, rootFile, kid string) (string, error) {
|
func GetProvisionerKey(caURL, rootFile, kid string) (string, error) {
|
||||||
if len(rootFile) == 0 {
|
if rootFile == "" {
|
||||||
rootFile = GetRootCAPath()
|
rootFile = GetRootCAPath()
|
||||||
}
|
}
|
||||||
client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile))
|
client, err := ca.NewClient(caURL, ca.WithRootFile(rootFile))
|
||||||
|
@ -315,17 +315,17 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
|
||||||
|
|
||||||
// Use /home/step as the step path in helm configurations.
|
// Use /home/step as the step path in helm configurations.
|
||||||
// Use the current step path when creating pki in files.
|
// Use the current step path when creating pki in files.
|
||||||
var public, private, config string
|
var public, private, cfg string
|
||||||
if p.options.isHelm {
|
if p.options.isHelm {
|
||||||
public = "/home/step/certs"
|
public = "/home/step/certs"
|
||||||
private = "/home/step/secrets"
|
private = "/home/step/secrets"
|
||||||
config = "/home/step/config"
|
cfg = "/home/step/config"
|
||||||
} else {
|
} else {
|
||||||
public = GetPublicPath()
|
public = GetPublicPath()
|
||||||
private = GetSecretsPath()
|
private = GetSecretsPath()
|
||||||
config = GetConfigPath()
|
cfg = GetConfigPath()
|
||||||
// Create directories
|
// Create directories
|
||||||
dirs := []string{public, private, config, GetTemplatesPath()}
|
dirs := []string{public, private, cfg, GetTemplatesPath()}
|
||||||
for _, name := range dirs {
|
for _, name := range dirs {
|
||||||
if _, err := os.Stat(name); os.IsNotExist(err) {
|
if _, err := os.Stat(name); os.IsNotExist(err) {
|
||||||
if err = os.MkdirAll(name, 0700); err != nil {
|
if err = os.MkdirAll(name, 0700); err != nil {
|
||||||
|
@ -380,10 +380,10 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
|
||||||
if p.Ssh.UserKey, err = getPath(private, "ssh_user_ca_key"); err != nil {
|
if p.Ssh.UserKey, err = getPath(private, "ssh_user_ca_key"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if p.defaults, err = getPath(config, "defaults.json"); err != nil {
|
if p.defaults, err = getPath(cfg, "defaults.json"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if p.config, err = getPath(config, "ca.json"); err != nil {
|
if p.config, err = getPath(cfg, "ca.json"); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
p.Defaults.CaConfig = p.config
|
p.Defaults.CaConfig = p.config
|
||||||
|
@ -620,16 +620,17 @@ func (p *PKI) askFeedback() {
|
||||||
|
|
||||||
func (p *PKI) tellPKI() {
|
func (p *PKI) tellPKI() {
|
||||||
ui.Println()
|
ui.Println()
|
||||||
if p.casOptions.Is(apiv1.SoftCAS) {
|
switch {
|
||||||
|
case p.casOptions.Is(apiv1.SoftCAS):
|
||||||
ui.PrintSelected("Root certificate", p.Root[0])
|
ui.PrintSelected("Root certificate", p.Root[0])
|
||||||
ui.PrintSelected("Root private key", p.RootKey[0])
|
ui.PrintSelected("Root private key", p.RootKey[0])
|
||||||
ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint)
|
ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint)
|
||||||
ui.PrintSelected("Intermediate certificate", p.Intermediate)
|
ui.PrintSelected("Intermediate certificate", p.Intermediate)
|
||||||
ui.PrintSelected("Intermediate private key", p.IntermediateKey)
|
ui.PrintSelected("Intermediate private key", p.IntermediateKey)
|
||||||
} else if p.Defaults.Fingerprint != "" {
|
case p.Defaults.Fingerprint != "":
|
||||||
ui.PrintSelected("Root certificate", p.Root[0])
|
ui.PrintSelected("Root certificate", p.Root[0])
|
||||||
ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint)
|
ui.PrintSelected("Root fingerprint", p.Defaults.Fingerprint)
|
||||||
} else {
|
default:
|
||||||
ui.Printf(`{{ "%s" | red }} {{ "Root certificate:" | bold }} failed to retrieve it from RA`+"\n", ui.IconBad)
|
ui.Printf(`{{ "%s" | red }} {{ "Root certificate:" | bold }} failed to retrieve it from RA`+"\n", ui.IconBad)
|
||||||
}
|
}
|
||||||
if p.options.enableSSH {
|
if p.options.enableSSH {
|
||||||
|
@ -657,7 +658,7 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
|
||||||
authorityOptions = &p.casOptions
|
authorityOptions = &p.casOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &authconfig.Config{
|
cfg := &authconfig.Config{
|
||||||
Root: p.Root,
|
Root: p.Root,
|
||||||
FederatedRoots: p.FederatedRoots,
|
FederatedRoots: p.FederatedRoots,
|
||||||
IntermediateCert: p.Intermediate,
|
IntermediateCert: p.Intermediate,
|
||||||
|
@ -681,7 +682,7 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
|
||||||
// Add linked as a deployment type to detect it on start and provide a
|
// Add linked as a deployment type to detect it on start and provide a
|
||||||
// message if the token is not given.
|
// message if the token is not given.
|
||||||
if p.options.deploymentType == LinkedDeployment {
|
if p.options.deploymentType == LinkedDeployment {
|
||||||
config.AuthorityConfig.DeploymentType = LinkedDeployment.String()
|
cfg.AuthorityConfig.DeploymentType = LinkedDeployment.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// On standalone deployments add the provisioners to either the ca.json or
|
// On standalone deployments add the provisioners to either the ca.json or
|
||||||
|
@ -711,7 +712,7 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
|
||||||
|
|
||||||
if p.options.enableSSH {
|
if p.options.enableSSH {
|
||||||
enableSSHCA := true
|
enableSSHCA := true
|
||||||
config.SSH = &authconfig.SSHConfig{
|
cfg.SSH = &authconfig.SSHConfig{
|
||||||
HostKey: p.Ssh.HostKey,
|
HostKey: p.Ssh.HostKey,
|
||||||
UserKey: p.Ssh.UserKey,
|
UserKey: p.Ssh.UserKey,
|
||||||
}
|
}
|
||||||
|
@ -733,19 +734,19 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
|
||||||
|
|
||||||
// Apply configuration modifiers
|
// Apply configuration modifiers
|
||||||
for _, o := range opt {
|
for _, o := range opt {
|
||||||
if err := o(config); err != nil {
|
if err := o(cfg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set authority.enableAdmin to true
|
// Set authority.enableAdmin to true
|
||||||
if p.options.enableAdmin {
|
if p.options.enableAdmin {
|
||||||
config.AuthorityConfig.EnableAdmin = true
|
cfg.AuthorityConfig.EnableAdmin = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.options.deploymentType == StandaloneDeployment {
|
if p.options.deploymentType == StandaloneDeployment {
|
||||||
if !config.AuthorityConfig.EnableAdmin {
|
if !cfg.AuthorityConfig.EnableAdmin {
|
||||||
config.AuthorityConfig.Provisioners = provisioners
|
cfg.AuthorityConfig.Provisioners = provisioners
|
||||||
} else {
|
} else {
|
||||||
// At this moment this code path is never used because `step ca
|
// At this moment this code path is never used because `step ca
|
||||||
// init` will always set enableAdmin to false for a standalone
|
// init` will always set enableAdmin to false for a standalone
|
||||||
|
@ -754,11 +755,11 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
|
||||||
//
|
//
|
||||||
// Note that we might want to be able to define the database as a
|
// Note that we might want to be able to define the database as a
|
||||||
// flag in `step ca init` so we can write to the proper place.
|
// flag in `step ca init` so we can write to the proper place.
|
||||||
db, err := db.New(config.DB)
|
_db, err := db.New(cfg.DB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
adminDB, err := admindb.New(db.(nosql.DB), admin.DefaultAuthorityID)
|
adminDB, err := admindb.New(_db.(nosql.DB), admin.DefaultAuthorityID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -788,7 +789,7 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, nil
|
return cfg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save stores the pki on a json file that will be used as the certificate
|
// Save stores the pki on a json file that will be used as the certificate
|
||||||
|
@ -804,12 +805,12 @@ func (p *PKI) Save(opt ...ConfigOption) error {
|
||||||
|
|
||||||
// Generate and write ca.json
|
// Generate and write ca.json
|
||||||
if !p.options.pkiOnly {
|
if !p.options.pkiOnly {
|
||||||
config, err := p.GenerateConfig(opt...)
|
cfg, err := p.GenerateConfig(opt...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := json.MarshalIndent(config, "", "\t")
|
b, err := json.MarshalIndent(cfg, "", "\t")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(err, "error marshaling %s", p.config)
|
return errors.Wrapf(err, "error marshaling %s", p.config)
|
||||||
}
|
}
|
||||||
|
@ -833,14 +834,14 @@ func (p *PKI) Save(opt ...ConfigOption) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate and write templates
|
// Generate and write templates
|
||||||
if err := generateTemplates(config.Templates); err != nil {
|
if err := generateTemplates(cfg.Templates); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.DB != nil {
|
if cfg.DB != nil {
|
||||||
ui.PrintSelected("Database folder", config.DB.DataSource)
|
ui.PrintSelected("Database folder", cfg.DB.DataSource)
|
||||||
}
|
}
|
||||||
if config.Templates != nil {
|
if cfg.Templates != nil {
|
||||||
ui.PrintSelected("Templates folder", GetTemplatesPath())
|
ui.PrintSelected("Templates folder", GetTemplatesPath())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -198,14 +198,14 @@ func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
provisioner, ok := p.(*provisioner.SCEP)
|
prov, ok := p.(*provisioner.SCEP)
|
||||||
if !ok {
|
if !ok {
|
||||||
api.WriteError(w, errors.New("provisioner must be of type SCEP"))
|
api.WriteError(w, errors.New("provisioner must be of type SCEP"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(provisioner))
|
ctx = context.WithValue(ctx, scep.ProvisionerContextKey, scep.Provisioner(prov))
|
||||||
next(w, r.WithContext(ctx))
|
next(w, r.WithContext(ctx))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"github.com/Masterminds/sprig/v3"
|
"github.com/Masterminds/sprig/v3"
|
||||||
|
@ -226,16 +227,13 @@ func (t *Template) Output(data interface{}) (Output, error) {
|
||||||
|
|
||||||
// backfill updates old templates with the required data.
|
// backfill updates old templates with the required data.
|
||||||
func (t *Template) backfill(b []byte) {
|
func (t *Template) backfill(b []byte) {
|
||||||
switch t.Name {
|
if strings.EqualFold(t.Name, "sshd_config.tpl") && len(t.RequiredData) == 0 {
|
||||||
case "sshd_config.tpl":
|
|
||||||
if len(t.RequiredData) == 0 {
|
|
||||||
a := bytes.TrimSpace(b)
|
a := bytes.TrimSpace(b)
|
||||||
b := bytes.TrimSpace([]byte(DefaultSSHTemplateData[t.Name]))
|
b := bytes.TrimSpace([]byte(DefaultSSHTemplateData[t.Name]))
|
||||||
if bytes.Equal(a, b) {
|
if bytes.Equal(a, b) {
|
||||||
t.RequiredData = []string{"Certificate", "Key"}
|
t.RequiredData = []string{"Certificate", "Key"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Output represents the text representation of a rendered template.
|
// Output represents the text representation of a rendered template.
|
||||||
|
|
Loading…
Reference in a new issue