Merge branch 'master' into hs/acme-eab

This commit is contained in:
Herman Slatman 2021-12-06 13:01:23 +01:00
commit d0c23973cc
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
83 changed files with 795 additions and 621 deletions

View file

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.15', '1.16', '1.17' ] go: [ '1.16', '1.17' ]
steps: steps:
- -
name: Checkout name: Checkout
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'latest' version: 'v1.43.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -58,6 +58,7 @@ jobs:
run: V=1 make ci run: V=1 make ci
- -
name: Codecov name: Codecov
if: matrix.go == '1.17'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v1.2.1
with: with:
file: ./coverage.out # optional file: ./coverage.out # optional

View file

@ -73,9 +73,3 @@ issues:
- error strings should not be capitalized or end with punctuation or a newline - error strings should not be capitalized or end with punctuation or a newline
- Wrapf call needs 1 arg but has 2 args - Wrapf call needs 1 arg but has 2 args
- cs.NegotiatedProtocolIsMutual is deprecated - cs.NegotiatedProtocolIsMutual is deprecated
# golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration
service:
golangci-lint-version: 1.19.x # use the fixed version to not introduce new linters unexpectedly
prepare:
- echo "here I can run custom commands, but no preparation needed for this repo"

View file

@ -4,15 +4,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased - 0.17.7] - DATE ## [Unreleased - 0.18.1] - DATE
### Added ### Added
- Support for generate extractable keys and certificates on a pkcs#11 module.
### Changed ### Changed
### Deprecated ### Deprecated
### Removed ### Removed
### Fixed ### Fixed
### Security ### Security
## [0.18.0] - 2021-11-17
### Added
- Support for multiple certificate authority contexts.
- Support for generating extractable keys and certificates on a pkcs#11 module.
### Changed
- Support two latest versions of golang (1.16, 1.17)
### Deprecated
- go 1.15 support
## [0.17.6] - 2021-10-20 ## [0.17.6] - 2021-10-20
### Notes ### Notes
- 0.17.5 failed in CI/CD - 0.17.5 failed in CI/CD

View file

@ -11,8 +11,9 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"math/big" "math/big"
"io"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@ -130,7 +131,7 @@ func jwkEncode(pub crypto.PublicKey) (string, error) {
e := big.NewInt(int64(pub.E)) e := big.NewInt(int64(pub.E))
// Field order is important. // Field order is important.
// See https://tools.ietf.org/html/rfc7638#section-3.3 for details. // See https://tools.ietf.org/html/rfc7638#section-3.3 for details.
return fmt.Sprintf(`{"e":"%s","kty":"RSA","n":"%s"}`, return fmt.Sprintf(`{"e":%q,"kty":"RSA","n":%q}`,
base64.RawURLEncoding.EncodeToString(e.Bytes()), base64.RawURLEncoding.EncodeToString(e.Bytes()),
base64.RawURLEncoding.EncodeToString(n.Bytes()), base64.RawURLEncoding.EncodeToString(n.Bytes()),
), nil ), nil
@ -151,7 +152,7 @@ func jwkEncode(pub crypto.PublicKey) (string, error) {
} }
// Field order is important. // Field order is important.
// See https://tools.ietf.org/html/rfc7638#section-3.3 for details. // See https://tools.ietf.org/html/rfc7638#section-3.3 for details.
return fmt.Sprintf(`{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`, return fmt.Sprintf(`{"crv":%q,"kty":"EC","x":%q,"y":%q}`,
p.Name, p.Name,
base64.RawURLEncoding.EncodeToString(x), base64.RawURLEncoding.EncodeToString(x),
base64.RawURLEncoding.EncodeToString(y), base64.RawURLEncoding.EncodeToString(y),
@ -402,7 +403,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -722,7 +723,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -922,7 +923,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -17,7 +17,7 @@ import (
) )
func link(url, typ string) string { func link(url, typ string) string {
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) return fmt.Sprintf("<%s>;rel=%q", url, typ)
} }
// Clock that returns time in UTC rounded to seconds. // Clock that returns time in UTC rounded to seconds.

View file

@ -7,7 +7,7 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -89,7 +89,7 @@ func TestHandler_GetDirectory(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -261,7 +261,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -404,7 +404,7 @@ func TestHandler_GetCertificate(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -660,7 +660,7 @@ func TestHandler_GetChallenge(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"crypto/rsa" "crypto/rsa"
"errors" "errors"
"io/ioutil" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -118,7 +118,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct. // parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
func (h *Handler) parseJWS(next nextHTTP) nextHTTP { func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
return return
@ -378,7 +378,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
} }
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
value: payload, value: payload,
isPostAsGet: string(payload) == "", isPostAsGet: len(payload) == 0,
isEmptyJSON: string(payload) == "{}", isEmptyJSON: string(payload) == "{}",
}) })
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))

View file

@ -8,7 +8,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -148,7 +147,7 @@ func TestHandler_addNonce(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -205,7 +204,7 @@ func TestHandler_addDirLink(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -332,7 +331,7 @@ func TestHandler_verifyContentType(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -400,7 +399,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -490,7 +489,7 @@ func TestHandler_parseJWS(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -689,7 +688,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -891,7 +890,7 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1087,7 +1086,7 @@ func TestHandler_extractJWK(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1454,7 +1453,7 @@ func TestHandler_validateJWS(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -7,7 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
@ -430,7 +430,7 @@ func TestHandler_GetOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1343,7 +1343,7 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1633,7 +1633,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -12,7 +12,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -92,7 +92,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
"error doing http GET for url %s with status code %d", u, resp.StatusCode)) "error doing http GET for url %s with status code %d", u, resp.StatusCode))
} }
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return WrapErrorISE(err, "error reading "+ return WrapErrorISE(err, "error reading "+
"response body for url %s", u) "response body for url %s", u)

View file

@ -15,7 +15,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math/big" "math/big"
"net" "net"
"net/http" "net/http"
@ -707,7 +706,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -733,7 +732,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -775,7 +774,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -818,7 +817,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
}, },
}, },
@ -860,7 +859,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
}, },
}, },

View file

@ -318,7 +318,7 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -435,7 +435,7 @@ func ParseCursor(r *http.Request) (cursor string, limit int, err error) {
if v := q.Get("limit"); len(v) > 0 { if v := q.Get("limit"); len(v) > 0 {
limit, err = strconv.Atoi(v) limit, err = strconv.Atoi(v)
if err != nil { if err != nil {
return "", 0, errors.Wrapf(err, "error converting %s to integer", v) return "", 0, errs.BadRequestErr(err, "limit '%s' is not an integer", v)
} }
} }
return return

View file

@ -16,7 +16,7 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"math/big" "math/big"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -788,7 +788,7 @@ func Test_caHandler_Health(t *testing.T) {
t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode) t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Health unexpected error = %v", err) t.Errorf("caHandler.Health unexpected error = %v", err)
@ -829,7 +829,7 @@ func Test_caHandler_Root(t *testing.T) {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err) t.Errorf("caHandler.Root unexpected error = %v", err)
@ -902,7 +902,7 @@ func Test_caHandler_Sign(t *testing.T) {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err) t.Errorf("caHandler.Root unexpected error = %v", err)
@ -954,7 +954,7 @@ func Test_caHandler_Renew(t *testing.T) {
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Renew unexpected error = %v", err) t.Errorf("caHandler.Renew unexpected error = %v", err)
@ -1015,7 +1015,7 @@ func Test_caHandler_Rekey(t *testing.T) {
t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Rekey unexpected error = %v", err) t.Errorf("caHandler.Rekey unexpected error = %v", err)
@ -1038,12 +1038,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
r *http.Request r *http.Request
} }
req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", nil) req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", http.NoBody)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", nil) reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", http.NoBody)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1087,7 +1087,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
expectedError400 := errs.BadRequest("force") expectedError400 := errs.BadRequest("limit 'abc' is not an integer")
expectedError400Bytes, err := json.Marshal(expectedError400) expectedError400Bytes, err := json.Marshal(expectedError400)
assert.FatalError(t, err) assert.FatalError(t, err)
expectedError500 := errs.InternalServer("force") expectedError500 := errs.InternalServer("force")
@ -1105,7 +1105,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Provisioners unexpected error = %v", err) t.Errorf("caHandler.Provisioners unexpected error = %v", err)
@ -1175,7 +1175,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Provisioners unexpected error = %v", err) t.Errorf("caHandler.Provisioners unexpected error = %v", err)
@ -1225,7 +1225,7 @@ func Test_caHandler_Roots(t *testing.T) {
t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Roots unexpected error = %v", err) t.Errorf("caHandler.Roots unexpected error = %v", err)
@ -1271,7 +1271,7 @@ func Test_caHandler_Federation(t *testing.T) {
t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Federation unexpected error = %v", err) t.Errorf("caHandler.Federation unexpected error = %v", err)

View file

@ -18,7 +18,7 @@ func (s *RekeyRequest) Validate() error {
return errs.BadRequest("missing csr") return errs.BadRequest("missing csr")
} }
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr") return errs.BadRequestErr(err, "invalid csr")
} }
return nil return nil
@ -26,15 +26,14 @@ func (s *RekeyRequest) Validate() error {
// Rekey is similar to renew except that the certificate will be renewed with new key from csr. // Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing peer certificate")) WriteError(w, errs.BadRequest("missing client certificate"))
return return
} }
var body RekeyRequest var body RekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -10,7 +10,7 @@ import (
// new one. // new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing peer certificate")) WriteError(w, errs.BadRequest("missing client certificate"))
return return
} }

View file

@ -49,7 +49,7 @@ func (r *RevokeRequest) Validate() (err error) {
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
@ -80,12 +80,12 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
// the client certificate Serial Number must match the serial number // the client certificate Serial Number must match the serial number
// being revoked. // being revoked.
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
WriteError(w, errs.BadRequest("missing ott or peer certificate")) WriteError(w, errs.BadRequest("missing ott or client certificate"))
return return
} }
opts.Crt = r.TLS.PeerCertificates[0] opts.Crt = r.TLS.PeerCertificates[0]
if opts.Crt.SerialNumber.String() != opts.Serial { if opts.Crt.SerialNumber.String() != opts.Serial {
WriteError(w, errs.BadRequest("revoke: serial number in mtls certificate different than body")) WriteError(w, errs.BadRequest("serial number in client certificate different than body"))
return return
} }
// TODO: should probably be checking if the certificate was revoked here. // TODO: should probably be checking if the certificate was revoked here.

View file

@ -6,7 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -233,7 +233,7 @@ func Test_caHandler_Revoke(t *testing.T) {
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -26,7 +26,7 @@ func (s *SignRequest) Validate() error {
return errs.BadRequest("missing csr") return errs.BadRequest("missing csr")
} }
if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil { if err := s.CsrPEM.CertificateRequest.CheckSignature(); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "invalid csr") return errs.BadRequestErr(err, "invalid csr")
} }
if s.OTT == "" { if s.OTT == "" {
return errs.BadRequest("missing ott") return errs.BadRequest("missing ott")
@ -50,7 +50,7 @@ type SignResponse struct {
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -49,16 +49,16 @@ type SSHSignRequest struct {
func (s *SSHSignRequest) Validate() error { func (s *SSHSignRequest) Validate() error {
switch { switch {
case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert: case s.CertType != "" && s.CertType != provisioner.SSHUserCert && s.CertType != provisioner.SSHHostCert:
return errors.Errorf("unknown certType %s", s.CertType) return errs.BadRequest("invalid certType '%s'", s.CertType)
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty publicKey") return errs.BadRequest("missing or empty publicKey")
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
default: default:
// Validate identity signature if provided // Validate identity signature if provided
if s.IdentityCSR.CertificateRequest != nil { if s.IdentityCSR.CertificateRequest != nil {
if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil { if err := s.IdentityCSR.CertificateRequest.CheckSignature(); err != nil {
return errors.Wrap(err, "invalid identityCSR") return errs.BadRequestErr(err, "invalid identityCSR")
} }
} }
return nil return nil
@ -185,7 +185,7 @@ func (r *SSHConfigRequest) Validate() error {
case provisioner.SSHUserCert, provisioner.SSHHostCert: case provisioner.SSHUserCert, provisioner.SSHHostCert:
return nil return nil
default: default:
return errors.Errorf("unsupported type %s", r.Type) return errs.BadRequest("invalid type '%s'", r.Type)
} }
} }
@ -208,9 +208,9 @@ type SSHCheckPrincipalRequest struct {
func (r *SSHCheckPrincipalRequest) Validate() error { func (r *SSHCheckPrincipalRequest) Validate() error {
switch { switch {
case r.Type != provisioner.SSHHostCert: case r.Type != provisioner.SSHHostCert:
return errors.Errorf("unsupported type %s", r.Type) return errs.BadRequest("unsupported type '%s'", r.Type)
case r.Principal == "": case r.Principal == "":
return errors.New("missing or empty principal") return errs.BadRequest("missing or empty principal")
default: default:
return nil return nil
} }
@ -232,7 +232,7 @@ type SSHBastionRequest struct {
// Validate checks the values of the SSHBastionRequest. // Validate checks the values of the SSHBastionRequest.
func (r *SSHBastionRequest) Validate() error { func (r *SSHBastionRequest) Validate() error {
if r.Hostname == "" { if r.Hostname == "" {
return errors.New("missing or empty hostname") return errs.BadRequest("missing or empty hostname")
} }
return nil return nil
} }
@ -250,19 +250,19 @@ type SSHBastionResponse struct {
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }
@ -270,7 +270,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if body.AddUserPublicKey != nil { if body.AddUserPublicKey != nil {
addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey) addUserPublicKey, err = ssh.ParsePublicKey(body.AddUserPublicKey)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing addUserPublicKey")) WriteError(w, errs.BadRequestErr(err, "error parsing addUserPublicKey"))
return return
} }
} }
@ -394,11 +394,11 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -426,11 +426,11 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
@ -465,11 +465,11 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }

View file

@ -4,7 +4,6 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -20,9 +19,9 @@ type SSHRekeyRequest struct {
func (s *SSHRekeyRequest) Validate() error { func (s *SSHRekeyRequest) Validate() error {
switch { switch {
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
case len(s.PublicKey) == 0: case len(s.PublicKey) == 0:
return errors.New("missing or empty public key") return errs.BadRequest("missing or empty public key")
default: default:
return nil return nil
} }
@ -40,19 +39,19 @@ type SSHRekeyResponse struct {
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }
publicKey, err := ssh.ParsePublicKey(body.PublicKey) publicKey, err := ssh.ParsePublicKey(body.PublicKey)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error parsing publicKey")) WriteError(w, errs.BadRequestErr(err, "error parsing publicKey"))
return return
} }

View file

@ -19,7 +19,7 @@ type SSHRenewRequest struct {
func (s *SSHRenewRequest) Validate() error { func (s *SSHRenewRequest) Validate() error {
switch { switch {
case s.OTT == "": case s.OTT == "":
return errors.New("missing or empty ott") return errs.BadRequest("missing or empty ott")
default: default:
return nil return nil
} }
@ -37,13 +37,13 @@ type SSHRenewResponse struct {
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
logOtt(w, body.OTT) logOtt(w, body.OTT)
if err := body.Validate(); err != nil { if err := body.Validate(); err != nil {
WriteError(w, errs.BadRequestErr(err)) WriteError(w, err)
return return
} }

View file

@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := ReadJSON(r.Body, &body); err != nil {
WriteError(w, errs.Wrap(http.StatusBadRequest, err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -10,7 +10,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -299,14 +299,14 @@ func Test_caHandler_SSHSign(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated}, {"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, userB64)), http.StatusCreated},
{"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated}, {"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, hostB64)), http.StatusCreated},
{"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated}, {"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q,"addUserCrt":%q}`, userB64, userB64)), http.StatusCreated},
{"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":"%s","identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated}, {"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":%q,"identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated},
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":%q,"ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized}, {"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized},
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden}, {"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden},
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden}, {"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
@ -338,7 +338,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SignSSH unexpected error = %v", err) t.Errorf("caHandler.SignSSH unexpected error = %v", err)
@ -368,10 +368,10 @@ func Test_caHandler_SSHRoots(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK}, {"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
} }
@ -392,7 +392,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHRoots unexpected error = %v", err) t.Errorf("caHandler.SSHRoots unexpected error = %v", err)
@ -422,10 +422,10 @@ func Test_caHandler_SSHFederation(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK}, {"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
} }
@ -446,7 +446,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHFederation unexpected error = %v", err) t.Errorf("caHandler.SSHFederation unexpected error = %v", err)
@ -506,7 +506,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHConfig unexpected error = %v", err) t.Errorf("caHandler.SSHConfig unexpected error = %v", err)
@ -553,7 +553,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err) t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err)
@ -604,7 +604,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err) t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err)
@ -659,7 +659,7 @@ func Test_caHandler_SSHBastion(t *testing.T) {
t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHBastion unexpected error = %v", err) t.Errorf("caHandler.SSHBastion unexpected error = %v", err)

View file

@ -3,7 +3,6 @@ package api
import ( import (
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil"
"log" "log"
"net/http" "net/http"
@ -94,7 +93,7 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
// pointed by v. // pointed by v.
func ReadJSON(r io.Reader, v interface{}) error { func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil { if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.Wrap(http.StatusBadRequest, err, "error decoding json") return errs.BadRequestErr(err, "error decoding json")
} }
return nil return nil
} }
@ -102,9 +101,9 @@ func ReadJSON(r io.Reader, v interface{}) error {
// ReadProtoJSON reads JSON from the request body and stores it in the value // ReadProtoJSON reads JSON from the request body and stores it in the value
// pointed by v. // pointed by v.
func ReadProtoJSON(r io.Reader, m proto.Message) error { func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := ioutil.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return errs.Wrap(http.StatusBadRequest, err, "error reading request body") return errs.BadRequestErr(err, "error reading request body")
} }
return protojson.Unmarshal(data, m) return protojson.Unmarshal(data, m)
} }

View file

@ -7,8 +7,8 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -195,7 +195,7 @@ func TestAuthority_GetDatabase(t *testing.T) {
} }
func TestNewEmbedded(t *testing.T) { func TestNewEmbedded(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")
@ -268,7 +268,7 @@ func TestNewEmbedded(t *testing.T) {
} }
func TestNewEmbedded_Sign(t *testing.T) { func TestNewEmbedded_Sign(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")
@ -294,7 +294,7 @@ func TestNewEmbedded_Sign(t *testing.T) {
} }
func TestNewEmbedded_GetTLSCertificate(t *testing.T) { func TestNewEmbedded_GetTLSCertificate(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")

View file

@ -2,14 +2,14 @@ package authority
import ( import (
"encoding/json" "encoding/json"
"io/ioutil"
"net/url" "net/url"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
) )
@ -245,7 +245,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
return "" return ""
} }
stepPath := filepath.ToSlash(config.StepPath()) stepPath := filepath.ToSlash(step.Path())
if !strings.HasSuffix(stepPath, "/") { if !strings.HasSuffix(stepPath, "/") {
stepPath += "/" stepPath += "/"
} }
@ -257,7 +257,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
panic(err) panic(err)
} }
if ok { if ok {
b, err := ioutil.ReadFile(config.StepAbs(fn)) b, err := os.ReadFile(step.Abs(fn))
if err != nil { if err != nil {
panic(errors.Wrapf(err, "error reading %s", fn)) panic(errors.Wrapf(err, "error reading %s", fn))
} }

View file

@ -9,9 +9,10 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"os"
"strings" "strings"
"time" "time"
@ -165,7 +166,7 @@ func newAWSConfig(certPath string) (*awsConfig, error) {
if certPath == "" { if certPath == "" {
certBytes = []byte(awsCertificate) certBytes = []byte(awsCertificate)
} else { } else {
if b, err := ioutil.ReadFile(certPath); err == nil { if b, err := os.ReadFile(certPath); err == nil {
certBytes = b certBytes = b
} else { } else {
return nil, errors.Wrapf(err, "error reading %s", certPath) return nil, errors.Wrapf(err, "error reading %s", certPath)
@ -569,7 +570,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
client := http.Client{} client := http.Client{}
// first get the token // first get the token
req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, nil) req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, http.NoBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -582,7 +583,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode) return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
} }
token, err := ioutil.ReadAll(resp.Body) token, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -602,7 +603,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) { func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) {
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,7 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"io/ioutil" "io"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
@ -173,7 +173,7 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error reading identity token response") return "", errors.Wrap(err, "error reading identity token response")
} }

View file

@ -107,7 +107,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
w.Write([]byte(t1)) w.Write([]byte(t1))
default: default:
w.Header().Add("Content-Type", "application/json") w.Header().Add("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{"access_token":"%s"}`, t1))) fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
} }
})) }))
defer srv.Close() defer srv.Close()

View file

@ -7,7 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -183,7 +183,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
return "", errors.Wrap(err, "error doing identity request, are you in a GCP VM?") return "", errors.Wrap(err, "error doing identity request, are you in a GCP VM?")
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error on identity request") return "", errors.Wrap(err, "error on identity request")
} }

View file

@ -228,7 +228,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Use options in the token. // Use options in the token.
if opts.CertType != "" { if opts.CertType != "" {
if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "jwk.AuthorizeSSHSign") return nil, errs.BadRequestErr(err, err.Error())
} }
} }
if opts.KeyID != "" { if opts.KeyID != "" {

View file

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
) )
@ -370,7 +371,7 @@ func newValidityValidator(min, max time.Duration) *validityValidator {
} }
// Valid validates the certificate validity settings (notBefore/notAfter) and // Valid validates the certificate validity settings (notBefore/notAfter) and
// and total duration. // total duration.
func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
var ( var (
na = cert.NotAfter.Truncate(time.Second) na = cert.NotAfter.Truncate(time.Second)
@ -381,22 +382,20 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
d := na.Sub(nb) d := na.Sub(nb)
if na.Before(now) { if na.Before(now) {
return errors.Errorf("notAfter cannot be in the past; na=%v", na) return errs.BadRequest("notAfter cannot be in the past; na=%v", na)
} }
if na.Before(nb) { if na.Before(nb) {
return errors.Errorf("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) return errs.BadRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
} }
if d < v.min { if d < v.min {
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v", return errs.BadRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
d, v.min)
} }
// NOTE: this check is not "technically correct". We're allowing the max // NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will // duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not // be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough. // apply a backdate). This is good enough.
if d > v.max+o.Backdate { if d > v.max+o.Backdate {
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v", return errs.BadRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
d, v.max+o.Backdate)
} }
return nil return nil
} }

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -55,7 +56,7 @@ type SignSSHOptions struct {
// Validate validates the given SignSSHOptions. // Validate validates the given SignSSHOptions.
func (o SignSSHOptions) Validate() error { func (o SignSSHOptions) Validate() error {
if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert { if o.CertType != "" && o.CertType != SSHUserCert && o.CertType != SSHHostCert {
return errors.Errorf("unknown certType %s", o.CertType) return errs.BadRequest("unknown certificate type '%s'", o.CertType)
} }
return nil return nil
} }
@ -335,11 +336,11 @@ type sshCertValidityValidator struct {
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error {
switch { switch {
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0") return errs.BadRequest("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()): case cert.ValidBefore < uint64(now().Unix()):
return errors.New("ssh certificate validBefore cannot be in the past") return errs.BadRequest("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter: case cert.ValidBefore < cert.ValidAfter:
return errors.New("ssh certificate validBefore cannot be before validAfter") return errs.BadRequest("ssh certificate validBefore cannot be before validAfter")
} }
var min, max time.Duration var min, max time.Duration
@ -351,9 +352,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
min = v.MinHostSSHCertDuration() min = v.MinHostSSHCertDuration()
max = v.MaxHostSSHCertDuration() max = v.MaxHostSSHCertDuration()
case 0: case 0:
return errors.New("ssh certificate type has not been set") return errs.BadRequest("ssh certificate type has not been set")
default: default:
return errors.Errorf("unknown ssh certificate type %d", cert.CertType) return errs.BadRequest("unknown ssh certificate type %d", cert.CertType)
} }
// To not take into account the backdate, time.Now() will be used to // To not take into account the backdate, time.Now() will be used to
@ -362,11 +363,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
switch { switch {
case dur < min: case dur < min:
return errors.Errorf("requested duration of %s is less than minimum "+ return errs.BadRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
"accepted duration for selected provisioner of %s", dur, min)
case dur > max+opts.Backdate: case dur > max+opts.Backdate:
return errors.Errorf("requested duration of %s is greater than maximum "+ return errs.BadRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
"accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default: default:
return nil return nil
} }

View file

@ -191,8 +191,7 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
} }
if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) { if claims.Subject != strconv.FormatUint(claims.sshCert.Serial, 10) {
return errs.BadRequest("sshpop.AuthorizeSSHRevoke; sshpop token subject " + return errs.BadRequest("sshpop token subject must be equivalent to sshpop certificate serial number")
"must be equivalent to sshpop certificate serial number")
} }
return nil return nil
} }
@ -205,7 +204,7 @@ func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Cert
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate") return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, nil return claims.sshCert, nil
@ -220,7 +219,7 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, nil, errs.BadRequest("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate") return nil, nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, []SignOption{ return claims.sshCert, []SignOption{
// Validate public key // Validate public key

View file

@ -258,7 +258,7 @@ func TestSSHPOP_AuthorizeSSHRevoke(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRevoke; sshpop token subject must be equivalent to sshpop certificate serial number"), err: errors.New("sshpop token subject must be equivalent to sshpop certificate serial number"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -337,7 +337,7 @@ func TestSSHPOP_AuthorizeSSHRenew(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRenew; sshpop certificate must be a host ssh certificate"), err: errors.New("sshpop certificate must be a host ssh certificate"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -419,7 +419,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
p: p, p: p,
token: tok, token: tok,
code: http.StatusBadRequest, code: http.StatusBadRequest,
err: errors.New("sshpop.AuthorizeSSHRekey; sshpop certificate must be a host ssh certificate"), err: errors.New("sshpop certificate must be a host ssh certificate"),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {

View file

@ -10,9 +10,9 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"strings" "strings"
"time" "time"
@ -188,7 +188,7 @@ func generateJWK() (*JWK, error) {
} }
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
fooPubB, err := ioutil.ReadFile("./testdata/certs/foo.pub") fooPubB, err := os.ReadFile("./testdata/certs/foo.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -196,7 +196,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
barPubB, err := ioutil.ReadFile("./testdata/certs/bar.pub") barPubB, err := os.ReadFile("./testdata/certs/bar.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -234,7 +234,7 @@ func generateSSHPOP() (*SSHPOP, error) {
return nil, err return nil, err
} }
userB, err := ioutil.ReadFile("./testdata/certs/ssh_user_ca_key.pub") userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,7 +242,7 @@ func generateSSHPOP() (*SSHPOP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
hostB, err := ioutil.ReadFile("./testdata/certs/ssh_host_ca_key.pub") hostB, err := os.ReadFile("./testdata/certs/ssh_host_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -271,7 +271,7 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Use options in the token. // Use options in the token.
if opts.CertType != "" { if opts.CertType != "" {
if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil { if certType, err = sshutil.CertTypeFromString(opts.CertType); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "x5c.AuthorizeSSHSign") return nil, errs.BadRequestErr(err, err.Error())
} }
} }
if opts.KeyID != "" { if opts.KeyID != "" {

View file

@ -6,14 +6,14 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
step "go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/linkedca" "go.step.sm/linkedca"
@ -238,6 +238,8 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
return nil return nil
} }
// CreateFirstProvisioner creates and stores the first provisioner when using
// admin database provisioner storage.
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) { func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
if password == "" { if password == "" {
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one") pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
@ -287,6 +289,7 @@ func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (
return p, nil return p, nil
} }
// ValidateClaims validates the Claims type.
func ValidateClaims(c *linkedca.Claims) error { func ValidateClaims(c *linkedca.Claims) error {
if c == nil { if c == nil {
return nil return nil
@ -313,6 +316,7 @@ func ValidateClaims(c *linkedca.Claims) error {
return nil return nil
} }
// ValidateDurations validates the Durations type.
func ValidateDurations(d *linkedca.Durations) error { func ValidateDurations(d *linkedca.Durations) error {
var ( var (
err error err error
@ -523,8 +527,8 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.X509.Template != "" { if p.X509.Template != "" {
x509Template.Template = []byte(p.SSH.Template) x509Template.Template = []byte(p.SSH.Template)
} else if p.X509.TemplateFile != "" { } else if p.X509.TemplateFile != "" {
filename := step.StepAbs(p.X509.TemplateFile) filename := step.Abs(p.X509.TemplateFile)
if x509Template.Template, err = ioutil.ReadFile(filename); err != nil { if x509Template.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading x509 template") return nil, nil, errors.Wrap(err, "error reading x509 template")
} }
} }
@ -539,8 +543,8 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.SSH.Template != "" { if p.SSH.Template != "" {
sshTemplate.Template = []byte(p.SSH.Template) sshTemplate.Template = []byte(p.SSH.Template)
} else if p.SSH.TemplateFile != "" { } else if p.SSH.TemplateFile != "" {
filename := step.StepAbs(p.SSH.TemplateFile) filename := step.Abs(p.SSH.TemplateFile)
if sshTemplate.Template, err = ioutil.ReadFile(filename); err != nil { if sshTemplate.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading ssh template") return nil, nil, errors.Wrap(err, "error reading ssh template")
} }
} }

View file

@ -69,7 +69,7 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
ts = a.templates.SSH.Host ts = a.templates.SSH.Host
} }
default: default:
return nil, errs.BadRequest("getSSHConfig: type %s is not valid", typ) return nil, errs.BadRequest("invalid certificate type '%s'", typ)
} }
// Merge user and default data // Merge user and default data
@ -94,13 +94,22 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
// Check for required variables. // Check for required variables.
if err := t.ValidateRequiredData(data); err != nil { if err := t.ValidateRequiredData(data); err != nil {
return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set <key=value>` flag", err)) return nil, errs.BadRequestErr(err, "%v, please use `--set <key=value>` flag", err)
} }
o, err := t.Output(mergedData) o, err := t.Output(mergedData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Backwards compatibility for version of the cli older than v0.18.0.
// Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
// from the cli.
if o.Name == "step_includes.tpl" && data[templates.SSHTemplateVersionKey] == "" {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
}
output = append(output, o) output = append(output, o)
} }
return output, nil return output, nil
@ -142,7 +151,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Validate given options. // Validate given options.
if err := opts.Validate(); err != nil { if err := opts.Validate(); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") return nil, err
} }
// Set backdate with the configured value // Set backdate with the configured value
@ -185,8 +194,8 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
certificate, err := sshutil.NewCertificate(cr, certOptions...) certificate, err := sshutil.NewCertificate(cr, certOptions...)
if err != nil { if err != nil {
if _, ok := err.(*sshutil.TemplateError); ok { if _, ok := err.(*sshutil.TemplateError); ok {
return nil, errs.NewErr(http.StatusBadRequest, err, return nil, errs.ApplyOptions(
errs.WithMessage(err.Error()), errs.BadRequestErr(err, err.Error()),
errs.WithKeyVal("signOptions", signOpts), errs.WithKeyVal("signOptions", signOpts),
) )
} }
@ -199,7 +208,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// Use SignSSHOptions to modify the certificate validity. It will be later // Use SignSSHOptions to modify the certificate validity. It will be later
// checked or set if not defined. // checked or set if not defined.
if err := opts.ModifyValidity(certTpl); err != nil { if err := opts.ModifyValidity(certTpl); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.SignSSH") return nil, errs.BadRequestErr(err, err.Error())
} }
// Use provisioner modifiers. // Use provisioner modifiers.
@ -249,7 +258,7 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
// RenewSSH creates a signed SSH certificate using the old SSH certificate as a template. // RenewSSH creates a signed SSH certificate using the old SSH certificate as a template.
func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) { func (a *Authority) RenewSSH(ctx context.Context, oldCert *ssh.Certificate) (*ssh.Certificate, error) {
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest("renewSSH: cannot renew certificate without validity period") return nil, errs.BadRequest("cannot renew a certificate without validity period")
} }
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
@ -320,7 +329,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
} }
if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 { if oldCert.ValidAfter == 0 || oldCert.ValidBefore == 0 {
return nil, errs.BadRequest("rekeySSH; cannot rekey certificate without validity period") return nil, errs.BadRequest("cannot rekey a certificate without validity period")
} }
if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil { if err := a.authorizeSSHCertificate(ctx, oldCert); err != nil {
@ -360,7 +369,7 @@ func (a *Authority) RekeySSH(ctx context.Context, oldCert *ssh.Certificate, pub
} }
signer = a.sshCAHostCertSignKey signer = a.sshCAHostCertSignKey
default: default:
return nil, errs.BadRequest("rekeySSH; unexpected ssh certificate type: %d", cert.CertType) return nil, errs.BadRequest("unexpected certificate type '%d'", cert.CertType)
} }
var err error var err error

View file

@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")}, {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
} }
tmplConfigUserIncludes := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}
userOutputEmptyData := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
}
userOutputWithoutTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
userOutputWithTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
tmplConfigErr := &templates.Templates{ tmplConfigErr := &templates.Templates{
SSH: &templates.SSHTemplates{ SSH: &templates.SSHTemplates{
User: []templates.Template{ User: []templates.Template{
@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},
@ -883,7 +912,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{}, cert: &ssh.Certificate{},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"), err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },
@ -894,7 +923,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())}, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix())},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; cannot rekey certificate without validity period"), err: errors.New("cannot rekey a certificate without validity period"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },
@ -927,7 +956,7 @@ func TestAuthority_RekeySSH(t *testing.T) {
cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0}, cert: &ssh.Certificate{ValidAfter: uint64(now.Unix()), ValidBefore: uint64(now.Add(10 * time.Minute).Unix()), CertType: 0},
key: pub, key: pub,
signOpts: []provisioner.SignOption{}, signOpts: []provisioner.SignOption{},
err: errors.New("rekeySSH; unexpected ssh certificate type: 0"), err: errors.New("unexpected certificate type '0'"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },

View file

@ -0,0 +1 @@
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}

View file

@ -76,7 +76,10 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
opts := []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)} opts := []interface{}{errs.WithKeyVal("csr", csr), errs.WithKeyVal("signOptions", signOpts)}
if err := csr.CheckSignature(); err != nil { if err := csr.CheckSignature(); err != nil {
return nil, errs.Wrap(http.StatusBadRequest, err, "authority.Sign; invalid certificate request", opts...) return nil, errs.ApplyOptions(
errs.BadRequestErr(err, "invalid certificate request"),
opts...,
)
} }
// Set backdate with the configured value // Set backdate with the configured value
@ -114,8 +117,8 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
cert, err := x509util.NewCertificate(csr, certOptions...) cert, err := x509util.NewCertificate(csr, certOptions...)
if err != nil { if err != nil {
if _, ok := err.(*x509util.TemplateError); ok { if _, ok := err.(*x509util.TemplateError); ok {
return nil, errs.NewErr(http.StatusBadRequest, err, return nil, errs.ApplyOptions(
errs.WithMessage(err.Error()), errs.BadRequestErr(err, err.Error()),
errs.WithKeyVal("csr", csr), errs.WithKeyVal("csr", csr),
errs.WithKeyVal("signOptions", signOpts), errs.WithKeyVal("signOptions", signOpts),
) )
@ -433,8 +436,10 @@ func (a *Authority) Revoke(ctx context.Context, revokeOpts *RevokeOptions) error
case db.ErrNotImplemented: case db.ErrNotImplemented:
return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...) return errs.NotImplemented("authority.Revoke; no persistence layer configured", opts...)
case db.ErrAlreadyExists: case db.ErrAlreadyExists:
return errs.BadRequest("authority.Revoke; certificate with serial "+ return errs.ApplyOptions(
"number %s has already been revoked", append([]interface{}{rci.Serial}, opts...)...) errs.BadRequest("certificate with serial number '%s' is already revoked", rci.Serial),
opts...,
)
default: default:
return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.Revoke", opts...)
} }

View file

@ -256,7 +256,7 @@ func TestAuthority_Sign(t *testing.T) {
csr: csr, csr: csr,
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: signOpts, signOpts: signOpts,
err: errors.New("authority.Sign; invalid certificate request"), err: errors.New("invalid certificate request"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
} }
}, },
@ -310,7 +310,7 @@ func TestAuthority_Sign(t *testing.T) {
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: _signOpts, signOpts: _signOpts,
err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"), err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
code: http.StatusUnauthorized, code: http.StatusBadRequest,
} }
}, },
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest { "fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
@ -538,15 +538,15 @@ ZYtQ9Ot36qc=
if tc.csr.Subject.CommonName == "" { if tc.csr.Subject.CommonName == "" {
assert.Equals(t, leaf.Subject, pkix.Name{}) assert.Equals(t, leaf.Subject, pkix.Name{})
} else { } else {
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: "smallstep test", CommonName: "smallstep test",
})) }.String())
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"}) assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})
} }
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
@ -718,15 +718,15 @@ func TestAuthority_Renew(t *testing.T) {
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: tmplt.CommonName, CommonName: tmplt.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
@ -925,15 +925,15 @@ func TestAuthority_Rekey(t *testing.T) {
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: tmplt.CommonName, CommonName: tmplt.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
@ -1187,7 +1187,7 @@ func TestAuthority_Revoke(t *testing.T) {
Reason: reason, Reason: reason,
OTT: raw, OTT: raw,
}, },
err: errors.New("authority.Revoke; certificate with serial number sn has already been revoked"), err: errors.New("certificate with serial number 'sn' is already revoked"),
code: http.StatusBadRequest, code: http.StatusBadRequest,
checkErrDetails: func(err *errs.Error) { checkErrDetails: func(err *errs.Error) {
assert.Equals(t, err.Details["token"], raw) assert.Equals(t, err.Details["token"], raw)

View file

@ -7,7 +7,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"strings" "strings"
@ -292,7 +291,7 @@ func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Cert
return nil, nil, readACMEError(resp.Body) return nil, nil, readACMEError(resp.Body)
} }
defer resp.Body.Close() defer resp.Body.Close()
bodyBytes, err := ioutil.ReadAll(resp.Body) bodyBytes, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "error reading GET certificate response") return nil, nil, errors.Wrap(err, "error reading GET certificate response")
} }
@ -338,7 +337,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
func readACMEError(r io.ReadCloser) error { func readACMEError(r io.ReadCloser) error {
defer r.Close() defer r.Close()
b, err := ioutil.ReadAll(r) b, err := io.ReadAll(r)
if err != nil { if err != nil {
return errors.Wrap(err, "error reading from body") return errors.Wrap(err, "error reading from body")
} }

View file

@ -5,7 +5,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -317,7 +317,7 @@ func TestACMEClient_post(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -455,7 +455,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -575,7 +575,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -695,7 +695,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -815,7 +815,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -936,7 +936,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1061,7 +1061,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1188,7 +1188,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1317,7 +1317,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -197,7 +197,7 @@ func (c *AdminClient) GetAdminsPaginate(opts ...AdminOption) (*adminAPI.GetAdmin
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create GET %s request failed", u) return nil, errors.Wrapf(err, "create GET %s request failed", u)
} }
@ -284,7 +284,7 @@ func (c *AdminClient) RemoveAdmin(id string) error {
if err != nil { if err != nil {
return errors.Wrapf(err, "error generating admin token") return errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("DELETE", u.String(), nil) req, err := http.NewRequest("DELETE", u.String(), http.NoBody)
if err != nil { if err != nil {
return errors.Wrapf(err, "create DELETE %s request failed", u) return errors.Wrapf(err, "create DELETE %s request failed", u)
} }
@ -363,7 +363,7 @@ func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provi
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create PUT %s request failed", u) return nil, errors.Wrapf(err, "create PUT %s request failed", u)
} }
@ -402,7 +402,7 @@ func (c *AdminClient) GetProvisionersPaginate(opts ...ProvisionerOption) (*admin
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create PUT %s request failed", u) return nil, errors.Wrapf(err, "create PUT %s request failed", u)
} }
@ -472,7 +472,7 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error {
if err != nil { if err != nil {
return errors.Wrapf(err, "error generating admin token") return errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("DELETE", u.String(), nil) req, err := http.NewRequest("DELETE", u.String(), http.NoBody)
if err != nil { if err != nil {
return errors.Wrapf(err, "create DELETE %s request failed", u) return errors.Wrapf(err, "create DELETE %s request failed", u)
} }
@ -579,7 +579,7 @@ func (c *AdminClient) GetExternalAccountKeysPaginate(provisionerName, reference
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create GET %s request failed", u) return nil, errors.Wrapf(err, "create GET %s request failed", u)
} }
@ -647,7 +647,7 @@ func (c *AdminClient) RemoveExternalAccountKey(provisionerName, keyID string) er
if err != nil { if err != nil {
return errors.Wrapf(err, "error generating admin token") return errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("DELETE", u.String(), nil) req, err := http.NewRequest("DELETE", u.String(), http.NoBody)
if err != nil { if err != nil {
return errors.Wrapf(err, "create DELETE %s request failed", u) return errors.Wrapf(err, "create DELETE %s request failed", u)
} }

View file

@ -3,7 +3,7 @@ package ca
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -382,7 +382,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
return errors.Wrapf(err, "client.Get(%s) failed", srvURL) return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errors.Wrap(err, "client.Get() error reading response") return errors.Wrap(err, "client.Get() error reading response")
} }
@ -499,7 +499,7 @@ func TestBootstrapClientServerFederation(t *testing.T) {
return errors.Wrapf(err, "client.Get(%s) failed", srvURL) return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errors.Wrap(err, "client.Get() error reading response") return errors.Wrap(err, "client.Get() error reading response")
} }
@ -589,9 +589,9 @@ func TestBootstrapListener(t *testing.T) {
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.ReadAll() error = %v", err) t.Errorf("io.ReadAll() error = %v", err)
return return
} }
if string(b) != "ok" { if string(b) != "ok" {

View file

@ -115,7 +115,7 @@ func TestCASign(t *testing.T) {
ca: ca, ca: ca,
body: "invalid json", body: "invalid json",
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"fail invalid-csr-sig": func(t *testing.T) *signTest { "fail invalid-csr-sig": func(t *testing.T) *signTest {
@ -153,7 +153,7 @@ ZEp7knvU2psWRw==
ca: ca, ca: ca,
body: string(body), body: string(body),
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"fail unauthorized-ott": func(t *testing.T) *signTest { "fail unauthorized-ott": func(t *testing.T) *signTest {
@ -294,15 +294,15 @@ ZEp7knvU2psWRw==
assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second)) assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second)) assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{asn1dn.Country}, Country: []string{asn1dn.Country},
Organization: []string{asn1dn.Organization}, Organization: []string{asn1dn.Organization},
Locality: []string{asn1dn.Locality}, Locality: []string{asn1dn.Locality},
StreetAddress: []string{asn1dn.StreetAddress}, StreetAddress: []string{asn1dn.StreetAddress},
Province: []string{asn1dn.Province}, Province: []string{asn1dn.Province},
CommonName: asn1dn.CommonName, CommonName: asn1dn.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
@ -588,7 +588,7 @@ func TestCARenew(t *testing.T) {
ca: ca, ca: ca,
tlsConnState: nil, tlsConnState: nil,
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"request-missing-peer-certificate": func(t *testing.T) *renewTest { "request-missing-peer-certificate": func(t *testing.T) *renewTest {
@ -596,7 +596,7 @@ func TestCARenew(t *testing.T) {
ca: ca, ca: ca,
tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}}, tlsConnState: &tls.ConnectionState{PeerCertificates: []*x509.Certificate{}},
status: http.StatusBadRequest, status: http.StatusBadRequest,
errMsg: errs.BadRequestDefaultMsg, errMsg: errs.BadRequestPrefix,
} }
}, },
"success": func(t *testing.T) *renewTest { "success": func(t *testing.T) *renewTest {
@ -641,10 +641,10 @@ func TestCARenew(t *testing.T) {
assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second)) assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second)) assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
CommonName: asn1dn.CommonName, CommonName: asn1dn.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)

View file

@ -15,7 +15,6 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -29,7 +28,7 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -75,7 +74,7 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) {
} }
func (c *uaClient) Get(u string) (*http.Response, error) { func (c *uaClient) Get(u string) (*http.Response, error) {
req, err := http.NewRequest("GET", u, nil) req, err := http.NewRequest("GET", u, http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "new request GET %s failed", u) return nil, errors.Wrapf(err, "new request GET %s failed", u)
} }
@ -226,7 +225,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
return tr, nil return tr, nil
} }
// WithTransport adds a custom transport to the Client. It will fail if a // WithTransport adds a custom transport to the Client. It will fail if a
// previous option to create the transport has been configured. // previous option to create the transport has been configured.
func WithTransport(tr http.RoundTripper) ClientOption { func WithTransport(tr http.RoundTripper) ClientOption {
return func(o *clientOptions) error { return func(o *clientOptions) error {
@ -238,6 +237,17 @@ func WithTransport(tr http.RoundTripper) ClientOption {
} }
} }
// WithInsecure adds a insecure transport that bypasses TLS verification.
func WithInsecure() ClientOption {
return func(o *clientOptions) error {
o.transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return nil
}
}
// WithRootFile will create the transport using the given root certificate. It // WithRootFile will create the transport using the given root certificate. It
// will fail if a previous option to create the transport has been configured. // will fail if a previous option to create the transport has been configured.
func WithRootFile(filename string) ClientOption { func WithRootFile(filename string) ClientOption {
@ -350,7 +360,7 @@ func WithRetryFunc(fn RetryFunc) ClientOption {
} }
func getTransportFromFile(filename string) (http.RoundTripper, error) { func getTransportFromFile(filename string) (http.RoundTripper, error) {
data, err := ioutil.ReadFile(filename) data, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename) return nil, errors.Wrapf(err, "error reading %s", filename)
} }
@ -652,7 +662,7 @@ retry:
// verify the sha256 // verify the sha256
sum := sha256.Sum256(root.RootPEM.Raw) sum := sha256.Sum256(root.RootPEM.Raw)
if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) { if !strings.EqualFold(sha256Sum, strings.ToLower(hex.EncodeToString(sum[:]))) {
return nil, errs.BadRequest("client.Root; root certificate SHA256 fingerprint do not match") return nil, errs.BadRequest("root certificate fingerprint does not match")
} }
return &root, nil return &root, nil
} }
@ -1098,8 +1108,7 @@ retry:
retried = true retried = true
goto retry goto retry
} }
return nil, readError(resp.Body)
return nil, errs.StatusCodeError(resp.StatusCode, readError(resp.Body))
} }
var check api.SSHCheckPrincipalResponse var check api.SSHCheckPrincipalResponse
if err := readJSON(resp.Body, &check); err != nil { if err := readJSON(resp.Body, &check); err != nil {
@ -1295,7 +1304,7 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
// getRootCAPath returns the path where the root CA is stored based on the // getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // STEPPATH environment variable.
func getRootCAPath() string { func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt") return filepath.Join(step.Path(), "certs", "root_ca.crt")
} }
func readJSON(r io.ReadCloser, v interface{}) error { func readJSON(r io.ReadCloser, v interface{}) error {
@ -1305,7 +1314,7 @@ func readJSON(r io.ReadCloser, v interface{}) error {
func readProtoJSON(r io.ReadCloser, m proto.Message) error { func readProtoJSON(r io.ReadCloser, m proto.Message) error {
defer r.Close() defer r.Close()
data, err := ioutil.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return err return err
} }

View file

@ -337,8 +337,8 @@ func TestClient_Sign(t *testing.T) {
}{ }{
{"ok", request, ok, 200, false, nil}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"empty request", &api.SignRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix + "force.")},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -410,7 +410,7 @@ func TestClient_Revoke(t *testing.T) {
}{ }{
{"ok", request, ok, 200, false, nil}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -455,7 +455,7 @@ func TestClient_Revoke(t *testing.T) {
if got != nil { if got != nil {
t.Errorf("Client.Revoke() = %v, want nil", got) t.Errorf("Client.Revoke() = %v, want nil", got)
} }
assert.HasPrefix(t, tt.expectedErr.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.expectedErr.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Revoke() = %v, want %v", got, tt.response) t.Errorf("Client.Revoke() = %v, want %v", got, tt.response)
@ -484,8 +484,8 @@ func TestClient_Renew(t *testing.T) {
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -519,7 +519,7 @@ func TestClient_Renew(t *testing.T) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response) t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -553,8 +553,8 @@ func TestClient_Rekey(t *testing.T) {
}{ }{
{"ok", request, ok, 200, false, nil}, {"ok", request, ok, 200, false, nil},
{"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -588,7 +588,7 @@ func TestClient_Rekey(t *testing.T) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Renew() = %v, want %v", got, tt.response) t.Errorf("Client.Renew() = %v, want %v", got, tt.response)
@ -735,7 +735,7 @@ func TestClient_Roots(t *testing.T) {
}{ }{
{"ok", ok, 200, false, nil}, {"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, {"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"bad-request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -768,7 +768,7 @@ func TestClient_Roots(t *testing.T) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.Roots() = %v, want %v", got, tt.response) t.Errorf("Client.Roots() = %v, want %v", got, tt.response)
@ -1016,7 +1016,7 @@ func TestClient_SSHBastion(t *testing.T) {
}{ }{
{"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil}, {"ok", &api.SSHBastionRequest{Hostname: "host.local"}, ok, 200, false, nil},
{"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil}, {"bad-response", &api.SSHBastionRequest{Hostname: "host.local"}, "bad json", 200, true, nil},
{"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, {"bad-request", &api.SSHBastionRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
} }
srv := httptest.NewServer(nil) srv := httptest.NewServer(nil)
@ -1050,7 +1050,7 @@ func TestClient_SSHBastion(t *testing.T) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode) assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, tt.err.Error(), err.Error()) assert.HasPrefix(t, err.Error(), tt.err.Error())
} }
default: default:
if !reflect.DeepEqual(got, tt.response) { if !reflect.DeepEqual(got, tt.response) {

View file

@ -5,9 +5,9 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -27,21 +27,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json // $STEPPATH/config/identity.json
func LoadClient() (*Client, error) { func LoadClient() (*Client, error) {
b, err := ioutil.ReadFile(DefaultsFile) defaultsFile := DefaultsFile()
b, err := os.ReadFile(defaultsFile)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", DefaultsFile) return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
} }
var defaults defaultsConfig var defaults defaultsConfig
if err := json.Unmarshal(b, &defaults); err != nil { if err := json.Unmarshal(b, &defaults); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile) return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile)
} }
if err := defaults.Validate(); err != nil { if err := defaults.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
caURL, err := url.Parse(defaults.CaURL) caURL, err := url.Parse(defaults.CaURL)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
if caURL.Scheme == "" { if caURL.Scheme == "" {
caURL.Scheme = "https" caURL.Scheme = "https"
@ -52,7 +53,7 @@ func LoadClient() (*Client, error) {
return nil, err return nil, err
} }
if err := identity.Validate(); err != nil { if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", IdentityFile) return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
} }
if kind := identity.Kind(); kind != MutualTLS { if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)
@ -65,7 +66,7 @@ func LoadClient() (*Client, error) {
} }
// RootCAs // RootCAs
b, err = ioutil.ReadFile(defaults.Root) b, err = os.ReadFile(defaults.Root)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error loading %s", defaults.Root) return nil, errors.Wrapf(err, "error loading %s", defaults.Root)
} }

View file

@ -3,14 +3,20 @@ package identity
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os"
"reflect" "reflect"
"testing" "testing"
) )
func returnInput(val string) func() string {
return func() string {
return val
}
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
oldIdentityFile := IdentityFile oldIdentityFile := IdentityFile
oldDefaultsFile := DefaultsFile oldDefaultsFile := DefaultsFile
@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
DefaultsFile = oldDefaultsFile DefaultsFile = oldDefaultsFile
}() }()
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
client, err := LoadClient() client, err := LoadClient()
if err != nil { if err != nil {
@ -40,7 +46,7 @@ func TestClient(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/certs/root_ca.crt") b, err := os.ReadFile("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -114,7 +120,7 @@ func TestLoadClient(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/certs/root_ca.crt") b, err := os.ReadFile("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", func() { {"ok", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, expected, false}, }, expected, false},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/missing.json" IdentityFile = returnInput("testdata/config/missing.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/fail.json" IdentityFile = returnInput("testdata/config/fail.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/missing.json" DefaultsFile = returnInput("testdata/config/missing.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/fail.json" DefaultsFile = returnInput("testdata/config/fail.json")
}, nil, true}, }, nil, true},
{"fail ca", func() { {"fail ca", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badca.json" DefaultsFile = returnInput("testdata/config/badca.json")
}, nil, true}, }, nil, true},
{"fail root", func() { {"fail root", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badroot.json" DefaultsFile = returnInput("testdata/config/badroot.json")
}, nil, true}, }, nil, true},
{"fail type", func() { {"fail type", func() {
IdentityFile = "testdata/config/badIdentity.json" IdentityFile = returnInput("testdata/config/badIdentity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -7,7 +7,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -16,7 +15,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -39,11 +38,18 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims. // DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute const DefaultLeeway = 1 * time.Minute
// IdentityFile contains the location of the identity file. var (
var IdentityFile = filepath.Join(config.StepPath(), "config", "identity.json") identityDir = step.IdentityPath
configDir = step.ConfigPath
// DefaultsFile contains the location of the defaults file. // IdentityFile contains a pointer to a function that outputs the location of
var DefaultsFile = filepath.Join(config.StepPath(), "config", "defaults.json") // the identity file.
IdentityFile = step.IdentityFile
// DefaultsFile contains a prointer a function that outputs the location of the
// defaults configuration file.
DefaultsFile = step.DefaultsFile
)
// Identity represents the identity file that can be used to authenticate with // Identity represents the identity file that can be used to authenticate with
// the CA. // the CA.
@ -61,7 +67,7 @@ type Identity struct {
// LoadIdentity loads an identity present in the given filename. // LoadIdentity loads an identity present in the given filename.
func LoadIdentity(filename string) (*Identity, error) { func LoadIdentity(filename string) (*Identity, error) {
b, err := ioutil.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename) return nil, errors.Wrapf(err, "error reading %s", filename)
} }
@ -74,23 +80,17 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity. // LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) { func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(IdentityFile) return LoadIdentity(IdentityFile())
} }
// configDir and identityDir are used in WriteDefaultIdentity for testing
// purposes.
var (
configDir = filepath.Join(config.StepPath(), "config")
identityDir = filepath.Join(config.StepPath(), "identity")
)
// WriteDefaultIdentity writes the given certificates and key and the // WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files. // identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(configDir, 0700); err != nil { if err := os.MkdirAll(configDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory") return errors.Wrap(err, "error creating config directory")
} }
identityDir := identityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil { if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory") return errors.Wrap(err, "error creating identity directory")
} }
@ -112,7 +112,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
if err := pem.Encode(buf, block); err != nil { if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity key") return errors.Wrap(err, "error encoding identity key")
} }
if err := ioutil.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -127,7 +127,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil { }); err != nil {
return errors.Wrap(err, "error writing identity json") return errors.Wrap(err, "error writing identity json")
} }
if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -136,7 +136,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk. // WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error { func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(identityDir, "identity.crt") filename := filepath.Join(identityDir(), "identity.crt")
return writeCertificate(filename, certChain) return writeCertificate(filename, certChain)
} }
@ -153,7 +153,7 @@ func writeCertificate(filename string, certChain []api.Certificate) error {
} }
} }
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing certificate") return errors.Wrap(err, "error writing certificate")
} }
@ -263,7 +263,7 @@ func (i *Identity) GetCertPool() (*x509.CertPool, error) {
if i.Root == "" { if i.Root == "" {
return nil, nil return nil, nil
} }
b, err := ioutil.ReadFile(i.Root) b, err := os.ReadFile(i.Root)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error reading identity root") return nil, errors.Wrap(err, "error reading identity root")
} }
@ -319,8 +319,8 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate") return errors.Wrap(err, "error encoding identity certificate")
} }
} }
certFilename := filepath.Join(identityDir, "identity.crt") certFilename := filepath.Join(identityDir(), "identity.crt")
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }

View file

@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) {
want *Identity want *Identity
wantErr bool wantErr bool
}{ }{
{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false}, {"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false},
{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true}, {"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true},
{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true}, {"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) {
certChain = append(certChain, api.Certificate{Certificate: c}) certChain = append(certChain, api.Certificate{Certificate: c})
} }
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(tmpDir, "config", "identity.json") IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json"))
type args struct { type args struct {
certChain []api.Certificate certChain []api.Certificate
@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) {
}{ }{
{"ok", func() {}, args{certChain, key}, false}, {"ok", func() {}, args{certChain, key}, false},
{"fail mkdir config", func() { {"fail mkdir config", func() {
configDir = filepath.Join(tmpDir, "identity", "identity.crt") configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail mkdir identity", func() { {"fail mkdir identity", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity", "identity.crt") identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail certificate", func() { {"fail certificate", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail key", func() { {"fail key", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, "badKey"}, true}, }, args{certChain, "badKey"}, true},
{"fail write identity", func() { {"fail write identity", func() {
configDir = filepath.Join(tmpDir, "bad-dir") configDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(configDir, "identity.json") IdentityFile = returnInput(filepath.Join(configDir(), "identity.json"))
os.MkdirAll(configDir, 0600) os.MkdirAll(configDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
} }
@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) {
} }
oldIdentityDir := identityDir oldIdentityDir := identityDir
identityDir = "testdata/identity" identityDir = returnInput("testdata/identity")
defer func() { defer func() {
identityDir = oldIdentityDir identityDir = oldIdentityDir
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) {
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail write identity", func() { {"fail write identity", func() {
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -1,8 +1,8 @@
package ca package ca
import ( import (
"io/ioutil"
"net/url" "net/url"
"os"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -45,7 +45,7 @@ func TestNewProvisioner(t *testing.T) {
defer ca.Close() defer ca.Close()
want := getTestProvisioner(t, ca.URL) want := getTestProvisioner(t, ca.URL)
caBundle, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -18,7 +18,7 @@ var minCertDuration = time.Minute
// TLSRenewer automatically renews a tls certificate using a RenewFunc. // TLSRenewer automatically renews a tls certificate using a RenewFunc.
type TLSRenewer struct { type TLSRenewer struct {
sync.RWMutex renewMutex sync.RWMutex
RenewCertificate RenewFunc RenewCertificate RenewFunc
cert *tls.Certificate cert *tls.Certificate
timer *time.Timer timer *time.Timer
@ -81,9 +81,9 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
func (r *TLSRenewer) Run() { func (r *TLSRenewer) Run() {
cert := r.getCertificate() cert := r.getCertificate()
next := r.nextRenewDuration(cert.Leaf.NotAfter) next := r.nextRenewDuration(cert.Leaf.NotAfter)
r.Lock() r.renewMutex.Lock()
r.timer = time.AfterFunc(next, r.renewCertificate) r.timer = time.AfterFunc(next, r.renewCertificate)
r.Unlock() r.renewMutex.Unlock()
} }
// RunContext starts the certificate renewer for the given certificate. // RunContext starts the certificate renewer for the given certificate.
@ -133,25 +133,25 @@ func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Cer
// if the timer does not fire e.g. when the CA is run from a laptop that // if the timer does not fire e.g. when the CA is run from a laptop that
// enters sleep mode. // enters sleep mode.
func (r *TLSRenewer) getCertificate() *tls.Certificate { func (r *TLSRenewer) getCertificate() *tls.Certificate {
r.RLock() r.renewMutex.RLock()
cert := r.cert cert := r.cert
r.RUnlock() r.renewMutex.RUnlock()
return cert return cert
} }
// getCertificateForCA returns the certificate using a read-only lock. It will // getCertificateForCA returns the certificate using a read-only lock. It will
// automatically renew the certificate if it has expired. // automatically renew the certificate if it has expired.
func (r *TLSRenewer) getCertificateForCA() *tls.Certificate { func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
r.RLock() r.renewMutex.RLock()
// Force certificate renewal if the timer didn't run. // Force certificate renewal if the timer didn't run.
// This is an special case that can happen after a computer sleep. // This is an special case that can happen after a computer sleep.
if time.Now().After(r.certNotAfter) { if time.Now().After(r.certNotAfter) {
r.RUnlock() r.renewMutex.RUnlock()
r.renewCertificate() r.renewCertificate()
r.RLock() r.renewMutex.RLock()
} }
cert := r.cert cert := r.cert
r.RUnlock() r.renewMutex.RUnlock()
return cert return cert
} }
@ -159,10 +159,10 @@ func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
// updates certNotAfter with 1m of delta; this will force the renewal of the // updates certNotAfter with 1m of delta; this will force the renewal of the
// certificate if it is about to expire. // certificate if it is about to expire.
func (r *TLSRenewer) setCertificate(cert *tls.Certificate) { func (r *TLSRenewer) setCertificate(cert *tls.Certificate) {
r.Lock() r.renewMutex.Lock()
r.cert = cert r.cert = cert
r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute) r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute)
r.Unlock() r.renewMutex.Unlock()
} }
func (r *TLSRenewer) renewCertificate() { func (r *TLSRenewer) renewCertificate() {
@ -175,9 +175,9 @@ func (r *TLSRenewer) renewCertificate() {
r.setCertificate(cert) r.setCertificate(cert)
next = r.nextRenewDuration(cert.Leaf.NotAfter) next = r.nextRenewDuration(cert.Leaf.NotAfter)
} }
r.Lock() r.renewMutex.Lock()
r.timer.Reset(next) r.timer.Reset(next)
r.Unlock() r.renewMutex.Unlock()
} }
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration { func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {

View file

@ -4,8 +4,8 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"os"
"reflect" "reflect"
"sort" "sort"
"testing" "testing"
@ -202,7 +202,7 @@ func TestAddRootsToRootCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -256,7 +256,7 @@ func TestAddRootsToClientCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -310,12 +310,12 @@ func TestAddFederationToRootCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -374,12 +374,12 @@ func TestAddFederationToClientCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -438,7 +438,7 @@ func TestAddRootsToCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -492,12 +492,12 @@ func TestAddFederationToCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -8,7 +8,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"io/ioutil" "io"
"log" "log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -221,7 +221,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err) t.Fatalf("ioutil.RealAdd() error = %v", err)
} }
@ -335,7 +335,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err) t.Errorf("ioutil.RealAdd() error = %v", err)
return return
@ -374,9 +374,9 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err) t.Errorf("io.ReadAll() error = %v", err)
return return
} }
if !bytes.Equal(b, []byte("ok")) { if !bytes.Equal(b, []byte("ok")) {

View file

@ -91,7 +91,7 @@ func mustSerializeCrt(filename string, certs ...*x509.Certificate) {
panic(err) panic(err)
} }
} }
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil {
panic(err) panic(err)
} }
} }
@ -105,7 +105,7 @@ func mustSerializeKey(filename string, key crypto.Signer) {
Type: "PRIVATE KEY", Type: "PRIVATE KEY",
Bytes: b, Bytes: b,
}) })
if err := ioutil.WriteFile(filename, b, 0600); err != nil { if err := os.WriteFile(filename, b, 0600); err != nil {
panic(err) panic(err)
} }
} }

View file

@ -21,7 +21,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
"go.step.sm/cli-utils/command" "go.step.sm/cli-utils/command"
"go.step.sm/cli-utils/command/version" "go.step.sm/cli-utils/command/version"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/cli-utils/usage" "go.step.sm/cli-utils/usage"
@ -49,7 +49,7 @@ var (
) )
func init() { func init() {
config.Set("Smallstep CA", Version, BuildTime) step.Set("Smallstep CA", Version, BuildTime)
authority.GlobalVersion.Version = Version authority.GlobalVersion.Version = Version
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }
@ -115,7 +115,7 @@ func main() {
app := cli.NewApp() app := cli.NewApp()
app.Name = "step-ca" app.Name = "step-ca"
app.HelpName = "step-ca" app.HelpName = "step-ca"
app.Version = config.Version() app.Version = step.Version()
app.Usage = "an online certificate authority for secure automated certificate management" app.Usage = "an online certificate authority for secure automated certificate management"
app.UsageText = `**step-ca** <config> [**--password-file**=<file>] app.UsageText = `**step-ca** <config> [**--password-file**=<file>]
[**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>] [**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>]

View file

@ -32,7 +32,7 @@ import (
// Config is a mapping of the cli flags. // Config is a mapping of the cli flags.
type Config struct { type Config struct {
KMS string KMS string
RootOnly bool GenerateRoot bool
RootObject string RootObject string
RootKeyObject string RootKeyObject string
RootSubject string RootSubject string
@ -58,22 +58,29 @@ func (c *Config) Validate() error {
switch { switch {
case c.KMS == "": case c.KMS == "":
return errors.New("flag `--kms` is required") return errors.New("flag `--kms` is required")
case c.CrtPath == "":
return errors.New("flag `--crt-cert-path` is required")
case c.RootFile != "" && c.KeyFile == "": case c.RootFile != "" && c.KeyFile == "":
return errors.New("flag `--root` requires flag `--key`") return errors.New("flag `--root-cert-file` requires flag `--root-key-file`")
case c.KeyFile != "" && c.RootFile == "": case c.KeyFile != "" && c.RootFile == "":
return errors.New("flag `--key` requires flag `--root`") return errors.New("flag `--root-key-file` requires flag `--root-cert-file`")
case c.RootOnly && c.RootFile != "":
return errors.New("flag `--root-only` is incompatible with flag `--root`")
case c.RootFile == "" && c.RootObject == "": case c.RootFile == "" && c.RootObject == "":
return errors.New("one of flag `--root` or `--root-cert` is required") return errors.New("one of flag `--root-cert-file` or `--root-cert-obj` is required")
case c.RootFile == "" && c.RootKeyObject == "": case c.KeyFile == "" && c.RootKeyObject == "":
return errors.New("one of flag `--root` or `--root-key` is required") return errors.New("one of flag `--root-key-file` or `--root-key-obj` is required")
case c.CrtKeyPath == "" && c.CrtKeyObject == "":
return errors.New("one of flag `--crt-key-path` or `--crt-key-obj` is required")
case c.RootFile == "" && c.GenerateRoot && c.RootKeyObject == "":
return errors.New("flag `--root-gen` requires flag `--root-key-obj`")
case c.RootFile == "" && c.GenerateRoot && c.RootPath == "":
return errors.New("flag `--root-gen` requires `--root-cert-path`")
default: default:
if c.RootFile != "" { if c.RootFile != "" {
c.GenerateRoot = false
c.RootObject = "" c.RootObject = ""
c.RootKeyObject = "" c.RootKeyObject = ""
} }
if c.RootOnly { if c.CrtKeyPath != "" {
c.CrtObject = "" c.CrtObject = ""
c.CrtKeyObject = "" c.CrtKeyObject = ""
} }
@ -101,21 +108,27 @@ func main() {
var c Config var c Config
flag.StringVar(&c.KMS, "kms", kmsuri, "PKCS #11 URI with the module-path and token to connect to the module.") flag.StringVar(&c.KMS, "kms", kmsuri, "PKCS #11 URI with the module-path and token to connect to the module.")
flag.StringVar(&c.Pin, "pin", "", "PKCS #11 PIN") flag.StringVar(&c.Pin, "pin", "", "PKCS #11 PIN")
flag.StringVar(&c.RootObject, "root-cert", "pkcs11:id=7330;object=root-cert", "PKCS #11 URI with object id and label to store the root certificate.") // Option 1: Generate new root
flag.StringVar(&c.RootPath, "root-cert-path", "root_ca.crt", "Location to write the root certificate.") flag.BoolVar(&c.GenerateRoot, "root-gen", true, "Enable the generation of a root key.")
flag.StringVar(&c.RootKeyObject, "root-key", "pkcs11:id=7330;object=root-key", "PKCS #11 URI with object id and label to store the root key.")
flag.StringVar(&c.RootSubject, "root-name", "PKCS #11 Smallstep Root", "Subject and Issuer of the root certificate.") flag.StringVar(&c.RootSubject, "root-name", "PKCS #11 Smallstep Root", "Subject and Issuer of the root certificate.")
flag.StringVar(&c.CrtObject, "crt-cert", "pkcs11:id=7331;object=intermediate-cert", "PKCS #11 URI with object id and label to store the intermediate certificate.") flag.StringVar(&c.RootObject, "root-cert-obj", "pkcs11:id=7330;object=root-cert", "PKCS #11 URI with object id and label to store the root certificate.")
flag.StringVar(&c.CrtPath, "crt-cert-path", "intermediate_ca.crt", "Location to write the intermediate certificate.") flag.StringVar(&c.RootKeyObject, "root-key-obj", "pkcs11:id=7330;object=root-key", "PKCS #11 URI with object id and label to store the root key.")
flag.StringVar(&c.CrtKeyObject, "crt-key", "pkcs11:id=7331;object=intermediate-key", "PKCS #11 URI with object id and label to store the intermediate certificate.") // Option 2: Read root from disk and sign intermediate
flag.StringVar(&c.RootFile, "root-cert-file", "", "Path to the root certificate to use.")
flag.StringVar(&c.KeyFile, "root-key-file", "", "Path to the root key to use.")
// Option 3: Generate certificate signing request
flag.StringVar(&c.CrtSubject, "crt-name", "PKCS #11 Smallstep Intermediate", "Subject of the intermediate certificate.") flag.StringVar(&c.CrtSubject, "crt-name", "PKCS #11 Smallstep Intermediate", "Subject of the intermediate certificate.")
flag.StringVar(&c.CrtKeyPath, "crt-key-path", "intermediate_ca_key", "Location to write the intermediate private key.") flag.StringVar(&c.CrtObject, "crt-cert-obj", "pkcs11:id=7331;object=intermediate-cert", "PKCS #11 URI with object id and label to store the intermediate certificate.")
flag.StringVar(&c.CrtKeyObject, "crt-key-obj", "pkcs11:id=7331;object=intermediate-key", "PKCS #11 URI with object id and label to store the intermediate certificate.")
// SSH certificates
flag.BoolVar(&c.EnableSSH, "ssh", false, "Enable the creation of ssh keys.")
flag.StringVar(&c.SSHHostKeyObject, "ssh-host-key", "pkcs11:id=7332;object=ssh-host-key", "PKCS #11 URI with object id and label to store the key used to sign SSH host certificates.") flag.StringVar(&c.SSHHostKeyObject, "ssh-host-key", "pkcs11:id=7332;object=ssh-host-key", "PKCS #11 URI with object id and label to store the key used to sign SSH host certificates.")
flag.StringVar(&c.SSHUserKeyObject, "ssh-user-key", "pkcs11:id=7333;object=ssh-user-key", "PKCS #11 URI with object id and label to store the key used to sign SSH user certificates.") flag.StringVar(&c.SSHUserKeyObject, "ssh-user-key", "pkcs11:id=7333;object=ssh-user-key", "PKCS #11 URI with object id and label to store the key used to sign SSH user certificates.")
flag.BoolVar(&c.RootOnly, "root-only", false, "Store only only the root certificate and sign and intermediate.") // Output files
flag.StringVar(&c.RootFile, "root", "", "Path to the root certificate to use.") flag.StringVar(&c.RootPath, "root-cert-path", "root_ca.crt", "Location to write the root certificate.")
flag.StringVar(&c.KeyFile, "key", "", "Path to the root key to use.") flag.StringVar(&c.CrtPath, "crt-cert-path", "intermediate_ca.crt", "Location to write the intermediate certificate.")
flag.BoolVar(&c.EnableSSH, "ssh", false, "Enable the creation of ssh keys.") flag.StringVar(&c.CrtKeyPath, "crt-key-path", "", "Location to write the intermediate private key.")
// Others
flag.BoolVar(&c.NoCerts, "no-certs", false, "Do not store certificates in the module.") flag.BoolVar(&c.NoCerts, "no-certs", false, "Do not store certificates in the module.")
flag.BoolVar(&c.Force, "force", false, "Force the delete of previous keys.") flag.BoolVar(&c.Force, "force", false, "Force the delete of previous keys.")
flag.BoolVar(&c.Extractable, "extractable", false, "Allow export of private keys under wrap.") flag.BoolVar(&c.Extractable, "extractable", false, "Allow export of private keys under wrap.")
@ -276,22 +289,8 @@ func createPKI(k kms.KeyManager, c Config) error {
// Root Certificate // Root Certificate
var signer crypto.Signer var signer crypto.Signer
var root *x509.Certificate var root *x509.Certificate
if c.RootFile != "" && c.KeyFile != "" { switch {
root, err = pemutil.ReadCertificate(c.RootFile) case c.GenerateRoot:
if err != nil {
return err
}
key, err := pemutil.Read(c.KeyFile)
if err != nil {
return err
}
var ok bool
if signer, ok = key.(crypto.Signer); !ok {
return errors.Errorf("key type '%T' does not implement a signer", key)
}
} else {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: c.RootKeyObject, Name: c.RootKeyObject,
SignatureAlgorithm: apiv1.ECDSAWithSHA256, SignatureAlgorithm: apiv1.ECDSAWithSHA256,
@ -331,7 +330,7 @@ func createPKI(k kms.KeyManager, c Config) error {
return errors.Wrap(err, "error parsing root certificate") return errors.Wrap(err, "error parsing root certificate")
} }
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts { if cm, ok := k.(kms.CertificateManager); ok && c.RootObject != "" && !c.NoCerts {
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{ if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.RootObject, Name: c.RootObject,
Certificate: root, Certificate: root,
@ -339,6 +338,8 @@ func createPKI(k kms.KeyManager, c Config) error {
}); err != nil { }); err != nil {
return err return err
} }
} else {
c.RootObject = ""
} }
if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{
@ -350,12 +351,31 @@ func createPKI(k kms.KeyManager, c Config) error {
ui.PrintSelected("Root Key", resp.Name) ui.PrintSelected("Root Key", resp.Name)
ui.PrintSelected("Root Certificate", c.RootPath) ui.PrintSelected("Root Certificate", c.RootPath)
if c.RootObject != "" {
ui.PrintSelected("Root Certificate Object", c.RootObject)
}
case c.RootFile != "" && c.KeyFile != "": // Read Root From File
root, err = pemutil.ReadCertificate(c.RootFile)
if err != nil {
return err
}
key, err := pemutil.Read(c.KeyFile)
if err != nil {
return err
}
var ok bool
if signer, ok = key.(crypto.Signer); !ok {
return errors.Errorf("key type '%T' does not implement a signer", key)
}
} }
// Intermediate Certificate // Intermediate Certificate
var keyName string var keyName string
var publicKey crypto.PublicKey var publicKey crypto.PublicKey
if c.RootOnly { var intSigner crypto.Signer
if c.CrtKeyPath != "" {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return errors.Wrap(err, "error creating intermediate key") return errors.Wrap(err, "error creating intermediate key")
@ -373,6 +393,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
publicKey = priv.Public() publicKey = priv.Public()
intSigner = priv
} else { } else {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: c.CrtKeyObject, Name: c.CrtKeyObject,
@ -384,56 +405,89 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
publicKey = resp.PublicKey publicKey = resp.PublicKey
keyName = resp.Name keyName = resp.Name
}
template := &x509.Certificate{ intSigner, err = k.CreateSigner(&resp.CreateSignerRequest)
IsCA: true, if err != nil {
NotBefore: now,
NotAfter: now.Add(time.Hour * 24 * 365 * 10),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
MaxPathLen: 0,
MaxPathLenZero: true,
Issuer: root.Subject,
Subject: pkix.Name{CommonName: c.CrtSubject},
SerialNumber: mustSerialNumber(),
SubjectKeyId: mustSubjectKeyID(publicKey),
}
b, err := x509.CreateCertificate(rand.Reader, template, root, publicKey, signer)
if err != nil {
return err
}
intermediate, err := x509.ParseCertificate(b)
if err != nil {
return errors.Wrap(err, "error parsing intermediate certificate")
}
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts {
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.CrtObject,
Certificate: intermediate,
Extractable: c.Extractable,
}); err != nil {
return err return err
} }
} }
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{ if root != nil {
Type: "CERTIFICATE", template := &x509.Certificate{
Bytes: b, IsCA: true,
}), 0600); err != nil { NotBefore: now,
return err NotAfter: now.Add(time.Hour * 24 * 365 * 10),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
MaxPathLen: 0,
MaxPathLenZero: true,
Issuer: root.Subject,
Subject: pkix.Name{CommonName: c.CrtSubject},
SerialNumber: mustSerialNumber(),
SubjectKeyId: mustSubjectKeyID(publicKey),
}
b, err := x509.CreateCertificate(rand.Reader, template, root, publicKey, signer)
if err != nil {
return err
}
intermediate, err := x509.ParseCertificate(b)
if err != nil {
return errors.Wrap(err, "error parsing intermediate certificate")
}
if cm, ok := k.(kms.CertificateManager); ok && c.CrtObject != "" && !c.NoCerts {
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.CrtObject,
Certificate: intermediate,
Extractable: c.Extractable,
}); err != nil {
return err
}
} else {
c.CrtObject = ""
}
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: b,
}), 0600); err != nil {
return err
}
} else {
// No root available, generate CSR for external root.
csrTemplate := x509.CertificateRequest{
Subject: pkix.Name{CommonName: c.CrtSubject},
SignatureAlgorithm: x509.ECDSAWithSHA256,
}
// step: generate the csr request
csrCertificate, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, intSigner)
if err != nil {
return err
}
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrCertificate,
}), 0600); err != nil {
return err
}
} }
if c.RootOnly { if c.CrtKeyPath != "" {
ui.PrintSelected("Intermediate Key", c.CrtKeyPath) ui.PrintSelected("Intermediate Key", c.CrtKeyPath)
} else { } else {
ui.PrintSelected("Intermediate Key", keyName) ui.PrintSelected("Intermediate Key", keyName)
} }
ui.PrintSelected("Intermediate Certificate", c.CrtPath) if root != nil {
ui.PrintSelected("Intermediate Certificate", c.CrtPath)
if c.CrtObject != "" {
ui.PrintSelected("Intermediate Certificate Object", c.CrtObject)
}
} else {
ui.PrintSelected("Intermediate Certificate Request", c.CrtPath)
}
if c.SSHHostKeyObject != "" { if c.SSHHostKeyObject != "" {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -98,7 +97,7 @@ To get a linked authority token:
var password []byte var password []byte
if passFile != "" { if passFile != "" {
if password, err = ioutil.ReadFile(passFile); err != nil { if password, err = os.ReadFile(passFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", passFile)) fatal(errors.Wrapf(err, "error reading %s", passFile))
} }
password = bytes.TrimRightFunc(password, unicode.IsSpace) password = bytes.TrimRightFunc(password, unicode.IsSpace)
@ -106,7 +105,7 @@ To get a linked authority token:
var sshHostPassword []byte var sshHostPassword []byte
if sshHostPassFile != "" { if sshHostPassFile != "" {
if sshHostPassword, err = ioutil.ReadFile(sshHostPassFile); err != nil { if sshHostPassword, err = os.ReadFile(sshHostPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", sshHostPassFile)) fatal(errors.Wrapf(err, "error reading %s", sshHostPassFile))
} }
sshHostPassword = bytes.TrimRightFunc(sshHostPassword, unicode.IsSpace) sshHostPassword = bytes.TrimRightFunc(sshHostPassword, unicode.IsSpace)
@ -114,7 +113,7 @@ To get a linked authority token:
var sshUserPassword []byte var sshUserPassword []byte
if sshUserPassFile != "" { if sshUserPassFile != "" {
if sshUserPassword, err = ioutil.ReadFile(sshUserPassFile); err != nil { if sshUserPassword, err = os.ReadFile(sshUserPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", sshUserPassFile)) fatal(errors.Wrapf(err, "error reading %s", sshUserPassFile))
} }
sshUserPassword = bytes.TrimRightFunc(sshUserPassword, unicode.IsSpace) sshUserPassword = bytes.TrimRightFunc(sshUserPassword, unicode.IsSpace)
@ -122,7 +121,7 @@ To get a linked authority token:
var issuerPassword []byte var issuerPassword []byte
if issuerPassFile != "" { if issuerPassFile != "" {
if issuerPassword, err = ioutil.ReadFile(issuerPassFile); err != nil { if issuerPassword, err = os.ReadFile(issuerPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", issuerPassFile)) fatal(errors.Wrapf(err, "error reading %s", issuerPassFile))
} }
issuerPassword = bytes.TrimRightFunc(issuerPassword, unicode.IsSpace) issuerPassword = bytes.TrimRightFunc(issuerPassword, unicode.IsSpace)

View file

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "os"
"unicode" "unicode"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -72,14 +72,14 @@ func exportAction(ctx *cli.Context) error {
} }
if passwordFile != "" { if passwordFile != "" {
b, err := ioutil.ReadFile(passwordFile) b, err := os.ReadFile(passwordFile)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", passwordFile) return errors.Wrapf(err, "error reading %s", passwordFile)
} }
cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
} }
if issuerPasswordFile != "" { if issuerPasswordFile != "" {
b, err := ioutil.ReadFile(issuerPasswordFile) b, err := os.ReadFile(issuerPasswordFile)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", issuerPasswordFile) return errors.Wrapf(err, "error reading %s", issuerPasswordFile)
} }

View file

@ -25,7 +25,7 @@ type Option func(e *Error) error
// message only if it is empty. // message only if it is empty.
func withDefaultMessage(format string, args ...interface{}) Option { func withDefaultMessage(format string, args ...interface{}) Option {
return func(e *Error) error { return func(e *Error) error {
if len(e.Msg) > 0 { if e.Msg != "" {
return e return e
} }
e.Msg = fmt.Sprintf(format, args...) e.Msg = fmt.Sprintf(format, args...)
@ -164,7 +164,8 @@ type Messenger interface {
func StatusCodeError(code int, e error, opts ...Option) error { func StatusCodeError(code int, e error, opts ...Option) error {
switch code { switch code {
case http.StatusBadRequest: case http.StatusBadRequest:
return BadRequestErr(e, opts...) opts = append(opts, withDefaultMessage(BadRequestDefaultMsg))
return NewErr(http.StatusBadRequest, e, opts...)
case http.StatusUnauthorized: case http.StatusUnauthorized:
return UnauthorizedErr(e, opts...) return UnauthorizedErr(e, opts...)
case http.StatusForbidden: case http.StatusForbidden:
@ -194,6 +195,21 @@ var (
NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs NotImplementedDefaultMsg = "The requested method is not implemented by the certificate authority. " + seeLogs
) )
var (
// BadRequestPrefix is the prefix added to the bad request messages that are
// directly sent to the cli.
BadRequestPrefix = "The request could not be completed: "
)
func formatMessage(status int, msg string) string {
switch status {
case http.StatusBadRequest:
return BadRequestPrefix + msg + "."
default:
return msg
}
}
// splitOptionArgs splits the variadic length args into string formatting args // splitOptionArgs splits the variadic length args into string formatting args
// and Option(s) to apply to an Error. // and Option(s) to apply to an Error.
func splitOptionArgs(args []interface{}) ([]interface{}, []Option) { func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
@ -218,6 +234,32 @@ func splitOptionArgs(args []interface{}) ([]interface{}, []Option) {
return args[:indexOptionStart], opts return args[:indexOptionStart], opts
} }
// New creates a new http error with the given status and message.
func New(status int, format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return &Error{
Status: status,
Msg: formatMessage(status, msg),
Err: errors.New(msg),
}
}
// NewError creates a new http error with the given error and message.
func NewError(status int, err error, format string, args ...interface{}) error {
if _, ok := err.(*Error); ok {
return err
}
msg := fmt.Sprintf(format, args...)
if _, ok := err.(StackTracer); !ok {
err = errors.Wrap(err, msg)
}
return &Error{
Status: status,
Msg: formatMessage(status, msg),
Err: err,
}
}
// NewErr returns a new Error. If the given error implements the StatusCoder // NewErr returns a new Error. If the given error implements the StatusCoder
// interface we will ignore the given status. // interface we will ignore the given status.
func NewErr(status int, err error, opts ...Option) error { func NewErr(status int, err error, opts ...Option) error {
@ -254,6 +296,18 @@ func Errorf(code int, format string, args ...interface{}) error {
return e return e
} }
// ApplyOptions applies the given options to the error if is the type *Error.
// TODO(mariano): try to get rid of this.
func ApplyOptions(err error, opts ...interface{}) error {
if e, ok := err.(*Error); ok {
_, o := splitOptionArgs(opts)
for _, fn := range o {
fn(e)
}
}
return err
}
// InternalServer creates a 500 error with the given format and arguments. // InternalServer creates a 500 error with the given format and arguments.
func InternalServer(format string, args ...interface{}) error { func InternalServer(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg)) args = append(args, withDefaultMessage(InternalServerErrorDefaultMsg))
@ -280,14 +334,12 @@ func NotImplementedErr(err error, opts ...Option) error {
// BadRequest creates a 400 error with the given format and arguments. // BadRequest creates a 400 error with the given format and arguments.
func BadRequest(format string, args ...interface{}) error { func BadRequest(format string, args ...interface{}) error {
args = append(args, withDefaultMessage(BadRequestDefaultMsg)) return New(http.StatusBadRequest, format, args...)
return Errorf(http.StatusBadRequest, format, args...)
} }
// BadRequestErr returns an 400 error with the given error. // BadRequestErr returns an 400 error with the given error.
func BadRequestErr(err error, opts ...Option) error { func BadRequestErr(err error, format string, args ...interface{}) error {
opts = append(opts, withDefaultMessage(BadRequestDefaultMsg)) return NewError(http.StatusBadRequest, err, format, args...)
return NewErr(http.StatusBadRequest, err, opts...)
} }
// Unauthorized creates a 401 error with the given format and arguments. // Unauthorized creates a 401 error with the given format and arguments.

View file

@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"time" "time"
@ -32,7 +32,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"time" "time"
@ -32,7 +32,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -2,7 +2,6 @@ all: binaries build up
binaries: binaries:
CGO_ENABLED=0 GOOS=linux go build -o ca/step-ca github.com/smallstep/certificates/cmd/step-ca CGO_ENABLED=0 GOOS=linux go build -o ca/step-ca github.com/smallstep/certificates/cmd/step-ca
CGO_ENABLED=0 GOOS=linux go build -o renewer/step github.com/smallstep/cli/cmd/step
build: build-nginx build-ca build-renewer build: build-nginx build-ca build-renewer
build-nginx: build-nginx:

View file

@ -1,7 +1,7 @@
FROM alpine:latest FROM smallstep/step-cli
USER root
RUN mkdir -p /var/local/step RUN mkdir -p /var/local/step
ADD step /usr/local/bin/step
ADD crontab /var/spool/cron/crontabs/root ADD crontab /var/spool/cron/crontabs/root
RUN chmod 0644 /var/spool/cron/crontabs/root RUN chmod 0644 /var/spool/cron/crontabs/root

17
go.mod
View file

@ -1,6 +1,6 @@
module github.com/smallstep/certificates module github.com/smallstep/certificates
go 1.15 go 1.16
require ( require (
cloud.google.com/go v0.83.0 cloud.google.com/go v0.83.0
@ -29,10 +29,10 @@ require (
github.com/rs/xid v1.2.1 github.com/rs/xid v1.2.1
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262
github.com/smallstep/nosql v0.3.8 github.com/smallstep/nosql v0.3.9
github.com/urfave/cli v1.22.4 github.com/urfave/cli v1.22.4
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.6.1 go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.13.0 go.step.sm/crypto v0.13.0
go.step.sm/linkedca v0.8.0 go.step.sm/linkedca v0.8.0
golang.org/x/crypto v0.0.0-20210915214749-c084706c2272 golang.org/x/crypto v0.0.0-20210915214749-c084706c2272
@ -44,11 +44,8 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/square/go-jose.v2 v2.6.0
) )
// avoid license conflict from juju/ansiterm until https://github.com/manifoldco/promptui/pull/181 //replace github.com/smallstep/nosql => ../nosql
// is merged or other dependency in path currently in violation fixes compliance
replace github.com/manifoldco/promptui => github.com/nguyer/promptui v0.8.1-0.20210517132806-70ccd4709797
// replace github.com/smallstep/nosql => ../nosql //replace go.step.sm/crypto => ../crypto
// replace go.step.sm/crypto => ../crypto
// replace go.step.sm/cli-utils => ../cli-utils //replace go.step.sm/cli-utils => ../cli-utils
// replace go.step.sm/linkedca => ../linkedca

20
go.sum
View file

@ -367,16 +367,15 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/lunixbochs/vtclean v0.0.0-20180621232353-2d01aacdc34a/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ=
github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA=
github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= github.com/mattn/go-colorable v0.1.8 h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8=
@ -426,8 +425,6 @@ github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxzi
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
github.com/newrelic/go-agent v2.15.0+incompatible h1:IB0Fy+dClpBq9aEoIrLyQXzU34JyI1xVTanPLB/+jvU= github.com/newrelic/go-agent v2.15.0+incompatible h1:IB0Fy+dClpBq9aEoIrLyQXzU34JyI1xVTanPLB/+jvU=
github.com/newrelic/go-agent v2.15.0+incompatible/go.mod h1:a8Fv1b/fYhFSReoTU6HDkTYIMZeSVNffmoS726Y0LzQ= github.com/newrelic/go-agent v2.15.0+incompatible/go.mod h1:a8Fv1b/fYhFSReoTU6HDkTYIMZeSVNffmoS726Y0LzQ=
github.com/nguyer/promptui v0.8.1-0.20210517132806-70ccd4709797 h1:unCiBzwNjcuVbP3bgM76z0ORyIuI4sspop1qhkQJ044=
github.com/nguyer/promptui v0.8.1-0.20210517132806-70ccd4709797/go.mod h1:CBMXL3a2sC3Q8TjpLcQt8w/3aQ23VSy6r7UFeCG6phA=
github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs= github.com/oklog/oklog v0.3.2/go.mod h1:FCV+B7mhrz4o+ueLpx+KqkyXRGMWOYEvfiXtdGtbWGs=
github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA=
github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo=
@ -497,8 +494,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd
github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
github.com/smallstep/nosql v0.3.8 h1:1/EWUbbEdz9ai0g9Fd09VekVjtxp+5+gIHpV2PdwW3o= github.com/smallstep/nosql v0.3.9 h1:YPy5PR3PXClqmpFaVv0wfXDXDc7NXGBE1auyU2c87dc=
github.com/smallstep/nosql v0.3.8/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs= github.com/smallstep/nosql v0.3.9/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
@ -562,8 +559,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.step.sm/cli-utils v0.6.1 h1:v31ctEh/BFPGU067fF9Y8u2EIg6LRldUbN2dc/+u/V8= go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds=
go.step.sm/cli-utils v0.6.1/go.mod h1:stgyXHHHi9KwcR86sgzDdFC6e/tAmpF4NbqwSK7q/GM= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E=
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8= go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8=
go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg= go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg=
@ -737,7 +734,6 @@ golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200828194041-157a740278f4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -754,8 +750,9 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164 h1:7ZDGnxgHAMw7thfC5bEos0RDAccZKxioiWBhfIe+tvw=
golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210915083310-ed5796bab164/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211031064116-611d5d643895 h1:iaNpwpnrgL5jzWS0vCNnfa8HqzxveCFpFx3uC/X4Tps=
golang.org/x/sys v0.0.0-20211031064116-611d5d643895/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -957,7 +954,6 @@ google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"crypto" "crypto"
"fmt" "fmt"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -165,7 +165,7 @@ func TestCloudKMS_Close(t *testing.T) {
func TestCloudKMS_CreateSigner(t *testing.T) { func TestCloudKMS_CreateSigner(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -223,7 +223,7 @@ func TestCloudKMS_CreateKey(t *testing.T) {
testError := fmt.Errorf("an error") testError := fmt.Errorf("an error")
alreadyExists := status.Error(codes.AlreadyExists, "already exists") alreadyExists := status.Error(codes.AlreadyExists, "already exists")
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -389,7 +389,7 @@ func TestCloudKMS_GetPublicKey(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
testError := fmt.Errorf("an error") testError := fmt.Errorf("an error")
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -7,7 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io" "io"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -17,7 +17,7 @@ import (
) )
func Test_newSigner(t *testing.T) { func Test_newSigner(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -70,7 +70,7 @@ func Test_newSigner(t *testing.T) {
} }
func Test_signer_Public(t *testing.T) { func Test_signer_Public(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -159,7 +159,7 @@ func Test_signer_Sign(t *testing.T) {
} }
func TestSigner_SignatureAlgorithm(t *testing.T) { func TestSigner_SignatureAlgorithm(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -11,7 +11,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -78,7 +78,7 @@ func TestSoftKMS_CreateSigner(t *testing.T) {
} }
// Read and decode file using standard packages // Read and decode file using standard packages
b, err := ioutil.ReadFile("testdata/priv.pem") b, err := os.ReadFile("testdata/priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -234,7 +234,7 @@ func TestSoftKMS_CreateKey(t *testing.T) {
} }
func TestSoftKMS_GetPublicKey(t *testing.T) { func TestSoftKMS_GetPublicKey(t *testing.T) {
b, err := ioutil.ReadFile("testdata/pub.pem") b, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -332,7 +332,7 @@ func TestSoftKMS_CreateDecrypter(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/rsa.priv.pem") b, err := os.ReadFile("testdata/rsa.priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -9,7 +9,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"io/ioutil"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@ -202,7 +201,7 @@ func TestNew(t *testing.T) {
}) })
// Load ssh test fixtures // Load ssh test fixtures
b, err := ioutil.ReadFile("testdata/ssh") b, err := os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -290,7 +289,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
} }
// Read and decode file using standard packages // Read and decode file using standard packages
b, err := ioutil.ReadFile("testdata/priv.pem") b, err := os.ReadFile("testdata/priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -315,7 +314,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
}) })
// Load ssh test fixtures // Load ssh test fixtures
sshPubKeyStr, err := ioutil.ReadFile("testdata/ssh.pub") sshPubKeyStr, err := os.ReadFile("testdata/ssh.pub")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -323,7 +322,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err = ioutil.ReadFile("testdata/ssh") b, err = os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -499,7 +498,7 @@ func TestSSHAgentKMS_CreateKey(t *testing.T) {
*/ */
func TestSSHAgentKMS_GetPublicKey(t *testing.T) { func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
b, err := ioutil.ReadFile("testdata/pub.pem") b, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -510,7 +509,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
} }
// Load ssh test fixtures // Load ssh test fixtures
b, err = ioutil.ReadFile("testdata/ssh.pub") b, err = os.ReadFile("testdata/ssh.pub")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -518,7 +517,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err = ioutil.ReadFile("testdata/ssh") b, err = os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -3,8 +3,8 @@ package uri
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"io/ioutil"
"net/url" "net/url"
"os"
"strings" "strings"
"unicode" "unicode"
@ -140,7 +140,7 @@ func readFile(path string) ([]byte, error) {
if err == nil && (u.Scheme == "" || u.Scheme == "file") && u.Path != "" { if err == nil && (u.Scheme == "" || u.Scheme == "file") && u.Path != "" {
path = u.Path path = u.Path
} }
b, err := ioutil.ReadFile(path) b, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", path) return nil, errors.Wrapf(err, "error reading %s", path)
} }

View file

@ -29,9 +29,9 @@ import (
"github.com/smallstep/certificates/kms" "github.com/smallstep/certificates/kms"
kmsapi "github.com/smallstep/certificates/kms/apiv1" kmsapi "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -87,44 +87,50 @@ const (
) )
// GetDBPath returns the path where the file-system persistence is stored // GetDBPath returns the path where the file-system persistence is stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetDBPath() string { func GetDBPath() string {
return filepath.Join(config.StepPath(), dbPath) return filepath.Join(step.Path(), dbPath)
} }
// GetConfigPath returns the directory where the configuration files are stored // GetConfigPath returns the directory where the configuration files are stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetConfigPath() string { func GetConfigPath() string {
return filepath.Join(config.StepPath(), configPath) return filepath.Join(step.Path(), configPath)
}
// GetProfileConfigPath returns the directory where the profile configuration
// files are stored based on the $(step path).
func GetProfileConfigPath() string {
return filepath.Join(step.ProfilePath(), configPath)
} }
// GetPublicPath returns the directory where the public keys are stored based on // GetPublicPath returns the directory where the public keys are stored based on
// the STEPPATH environment variable. // the $(step path).
func GetPublicPath() string { func GetPublicPath() string {
return filepath.Join(config.StepPath(), publicPath) return filepath.Join(step.Path(), publicPath)
} }
// GetSecretsPath returns the directory where the private keys are stored based // GetSecretsPath returns the directory where the private keys are stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetSecretsPath() string { func GetSecretsPath() string {
return filepath.Join(config.StepPath(), privatePath) return filepath.Join(step.Path(), privatePath)
} }
// GetRootCAPath returns the path where the root CA is stored based on the // GetRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // $(step path).
func GetRootCAPath() string { func GetRootCAPath() string {
return filepath.Join(config.StepPath(), publicPath, "root_ca.crt") return filepath.Join(step.Path(), publicPath, "root_ca.crt")
} }
// GetOTTKeyPath returns the path where the one-time token key is stored based // GetOTTKeyPath returns the path where the one-time token key is stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetOTTKeyPath() string { func GetOTTKeyPath() string {
return filepath.Join(config.StepPath(), privatePath, "ott_key") return filepath.Join(step.Path(), privatePath, "ott_key")
} }
// GetTemplatesPath returns the path where the templates are stored. // GetTemplatesPath returns the path where the templates are stored.
func GetTemplatesPath() string { func GetTemplatesPath() string {
return filepath.Join(config.StepPath(), templatesPath) return filepath.Join(step.Path(), templatesPath)
} }
// GetProvisioners returns the map of provisioners on the given CA. // GetProvisioners returns the map of provisioners on the given CA.
@ -286,20 +292,22 @@ func WithKeyURIs(rootKey, intermediateKey, hostKey, userKey string) Option {
// PKI represents the Public Key Infrastructure used by a certificate authority. // PKI represents the Public Key Infrastructure used by a certificate authority.
type PKI struct { type PKI struct {
linkedca.Configuration linkedca.Configuration
Defaults linkedca.Defaults Defaults linkedca.Defaults
casOptions apiv1.Options casOptions apiv1.Options
caService apiv1.CertificateAuthorityService caService apiv1.CertificateAuthorityService
caCreator apiv1.CertificateAuthorityCreator caCreator apiv1.CertificateAuthorityCreator
keyManager kmsapi.KeyManager keyManager kmsapi.KeyManager
config string config string
defaults string defaults string
ottPublicKey *jose.JSONWebKey profileDefaults string
ottPrivateKey *jose.JSONWebEncryption ottPublicKey *jose.JSONWebKey
options *options ottPrivateKey *jose.JSONWebEncryption
options *options
} }
// New creates a new PKI configuration. // New creates a new PKI configuration.
func New(o apiv1.Options, opts ...Option) (*PKI, error) { func New(o apiv1.Options, opts ...Option) (*PKI, error) {
currentCtx := step.Contexts().GetCurrent()
caService, err := cas.New(context.Background(), o) caService, err := cas.New(context.Background(), o)
if err != nil { if err != nil {
return nil, err return nil, err
@ -358,6 +366,9 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
cfg = GetConfigPath() cfg = GetConfigPath()
// Create directories // Create directories
dirs := []string{public, private, cfg, GetTemplatesPath()} dirs := []string{public, private, cfg, GetTemplatesPath()}
if currentCtx != nil {
dirs = append(dirs, GetProfileConfigPath())
}
for _, name := range dirs { for _, name := range dirs {
if _, err := os.Stat(name); os.IsNotExist(err) { if _, err := os.Stat(name); os.IsNotExist(err) {
if err = os.MkdirAll(name, 0700); err != nil { if err = os.MkdirAll(name, 0700); err != nil {
@ -415,6 +426,10 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
if p.defaults, err = getPath(cfg, "defaults.json"); err != nil { if p.defaults, err = getPath(cfg, "defaults.json"); err != nil {
return nil, err return nil, err
} }
if currentCtx != nil {
p.profileDefaults = currentCtx.ProfileDefaultsFile()
}
if p.config, err = getPath(cfg, "ca.json"); err != nil { if p.config, err = getPath(cfg, "ca.json"); err != nil {
return nil, err return nil, err
} }
@ -944,6 +959,18 @@ func (p *PKI) Save(opt ...ConfigOption) error {
if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil { if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil {
return errs.FileError(err, p.defaults) return errs.FileError(err, p.defaults)
} }
// If we're using contexts then write a blank object to the default profile
// configuration location.
if p.profileDefaults != "" {
if _, err := os.Stat(p.profileDefaults); os.IsNotExist(err) {
// Write with 0600 to be consistent with directories structure.
if err = fileutil.WriteFile(p.profileDefaults, []byte("{}"), 0600); err != nil {
return errs.FileError(err, p.profileDefaults)
}
} else if err != nil {
return errs.FileError(err, p.profileDefaults)
}
}
// Generate and write templates // Generate and write templates
if err := generateTemplates(cfg.Templates); err != nil { if err := generateTemplates(cfg.Templates); err != nil {
@ -958,6 +985,9 @@ func (p *PKI) Save(opt ...ConfigOption) error {
} }
ui.PrintSelected("Default configuration", p.defaults) ui.PrintSelected("Default configuration", p.defaults)
if p.profileDefaults != "" {
ui.PrintSelected("Default profile configuration", p.profileDefaults)
}
ui.PrintSelected("Certificate Authority configuration", p.config) ui.PrintSelected("Certificate Authority configuration", p.config)
if p.options.deploymentType != LinkedDeployment { if p.options.deploymentType != LinkedDeployment {
ui.Println() ui.Println()

View file

@ -6,9 +6,9 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// getTemplates returns all the templates enabled // getTemplates returns all the templates enabled
@ -44,7 +44,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }
@ -53,7 +53,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }

View file

@ -5,7 +5,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -167,7 +166,7 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation) return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation)
} }
case http.MethodPost: case http.MethodPost:
body, err := ioutil.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize))
if err != nil { if err != nil {
return SCEPRequest{}, err return SCEPRequest{}, err
} }

View file

@ -2,4 +2,4 @@
For documentation on `step-ca.service`, see [Running `step-ca` As A Daemon](https://smallstep.com/docs/step-ca/certificate-authority-server-production#running-step-ca-as-a-daemon). For documentation on `step-ca.service`, see [Running `step-ca` As A Daemon](https://smallstep.com/docs/step-ca/certificate-authority-server-production#running-step-ca-as-a-daemon).
For documentation on `cert-renewer@.*`, see [Automating Certificate Renewal](https://smallstep.com/docs/step-ca/certificate-authority-server-production#automate-x509-certificate-lifecycle-management) See also: There is a systemd certificate renewal timer, in the [`systemd` directory of `smallstep/cli`](https://github.com/smallstep/cli/tree/master/systemd).

View file

@ -1,28 +0,0 @@
[Unit]
Description=Certificate renewer for %I
After=network-online.target
Documentation=https://smallstep.com/docs/step-ca/certificate-authority-server-production
StartLimitIntervalSec=0
[Service]
Type=oneshot
User=root
Environment=STEPPATH=/etc/step-ca \
CERT_LOCATION=/etc/step/certs/%i.crt \
KEY_LOCATION=/etc/step/certs/%i.key
; ExecCondition checks if the certificate is ready for renewal,
; based on the exit status of the command.
; (In systemd 242 or below, you can use ExecStartPre= here.)
ExecCondition=/usr/bin/step certificate needs-renewal ${CERT_LOCATION}
; ExecStart renews the certificate, if ExecStartPre was successful.
ExecStart=/usr/bin/step ca renew --force ${CERT_LOCATION} ${KEY_LOCATION}
; Try to reload or restart the systemd service that relies on this cert-renewer
; If the relying service doesn't exist, forge ahead.
ExecStartPost=/usr/bin/env bash -c "if ! systemctl --quiet is-enabled %i.service ; then exit 0; fi; systemctl try-reload-or-restart %i"
[Install]
WantedBy=multi-user.target

View file

@ -1,18 +0,0 @@
[Unit]
Description=Certificate renewal timer for %I
Documentation=https://smallstep.com/docs/step-ca/certificate-authority-server-production
[Timer]
Persistent=true
; Run the timer unit every 5 minutes.
OnCalendar=*:1/5
; Always run the timer on time.
AccuracySec=1us
; Add jitter to prevent a "thundering hurd" of simultaneous certificate renewals.
RandomizedDelaySec=5m
[Install]
WantedBy=timers.target

View file

@ -2,7 +2,6 @@ package templates
import ( import (
"bytes" "bytes"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -10,8 +9,8 @@ import (
"github.com/Masterminds/sprig/v3" "github.com/Masterminds/sprig/v3"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// TemplateType defines how a template will be written in disk. // TemplateType defines how a template will be written in disk.
@ -20,6 +19,9 @@ type TemplateType string
const ( const (
// Snippet will mark a template as a part of a file. // Snippet will mark a template as a part of a file.
Snippet TemplateType = "snippet" Snippet TemplateType = "snippet"
// PrependLine is a template for prepending a single line to a file. If the
// line already exists in the file it will be removed first.
PrependLine TemplateType = "prepend-line"
// File will mark a templates as a full file. // File will mark a templates as a full file.
File TemplateType = "file" File TemplateType = "file"
// Directory will mark a template as a directory. // Directory will mark a template as a directory.
@ -99,7 +101,7 @@ func (t *SSHTemplates) Validate() (err error) {
return return
} }
// Template represents on template file. // Template represents a template file.
type Template struct { type Template struct {
*template.Template *template.Template
Name string `json:"name"` Name string `json:"name"`
@ -118,8 +120,8 @@ func (t *Template) Validate() error {
return nil return nil
case t.Name == "": case t.Name == "":
return errors.New("template name cannot be empty") return errors.New("template name cannot be empty")
case t.Type != Snippet && t.Type != File && t.Type != Directory: case t.Type != Snippet && t.Type != File && t.Type != Directory && t.Type != PrependLine:
return errors.Errorf("invalid template type %s, it must be %s, %s, or %s", t.Type, Snippet, File, Directory) return errors.Errorf("invalid template type %s, it must be %s, %s, %s, or %s", t.Type, Snippet, PrependLine, File, Directory)
case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0: case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0:
return errors.New("template template cannot be empty") return errors.New("template template cannot be empty")
case t.TemplatePath != "" && t.Type == Directory: case t.TemplatePath != "" && t.Type == Directory:
@ -132,7 +134,7 @@ func (t *Template) Validate() error {
if t.TemplatePath != "" { if t.TemplatePath != "" {
// Check for file // Check for file
st, err := os.Stat(config.StepAbs(t.TemplatePath)) st, err := os.Stat(step.Abs(t.TemplatePath))
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", t.TemplatePath) return errors.Wrapf(err, "error reading %s", t.TemplatePath)
} }
@ -166,8 +168,8 @@ func (t *Template) Load() error {
if t.Template == nil && t.Type != Directory { if t.Template == nil && t.Type != Directory {
switch { switch {
case t.TemplatePath != "": case t.TemplatePath != "":
filename := config.StepAbs(t.TemplatePath) filename := step.Abs(t.TemplatePath)
b, err := ioutil.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", filename) return errors.Wrapf(err, "error reading %s", filename)
} }
@ -247,7 +249,10 @@ type Output struct {
// Write writes the Output to the filesystem as a directory, file or snippet. // Write writes the Output to the filesystem as a directory, file or snippet.
func (o *Output) Write() error { func (o *Output) Write() error {
path := config.StepAbs(o.Path) // Replace ${STEPPATH} with the base step path.
o.Path = strings.ReplaceAll(o.Path, "${STEPPATH}", step.BasePath())
path := step.Abs(o.Path)
if o.Type == Directory { if o.Type == Directory {
return mkdir(path, 0700) return mkdir(path, 0700)
} }
@ -257,11 +262,17 @@ func (o *Output) Write() error {
return err return err
} }
if o.Type == File { switch o.Type {
case File:
return fileutil.WriteFile(path, o.Content, 0600) return fileutil.WriteFile(path, o.Content, 0600)
case Snippet:
return fileutil.WriteSnippet(path, o.Content, 0600)
case PrependLine:
return fileutil.PrependLine(path, o.Content, 0600)
default:
// Default to using a Snippet type if the type is not known.
return fileutil.WriteSnippet(path, o.Content, 0600)
} }
return fileutil.WriteSnippet(path, o.Content, 0600)
} }
func mkdir(path string, perm os.FileMode) error { func mkdir(path string, perm os.FileMode) error {

View file

@ -4,6 +4,10 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// SSHTemplateVersionKey is a key that can be submitted by a client to select
// the template version that will be returned by the server.
var SSHTemplateVersionKey = "StepSSHTemplateVersion"
// Step represents the default variables available in the CA. // Step represents the default variables available in the CA.
type Step struct { type Step struct {
SSH StepSSH SSH StepSSH
@ -22,16 +26,23 @@ type StepSSH struct {
var DefaultSSHTemplates = SSHTemplates{ var DefaultSSHTemplates = SSHTemplates{
User: []Template{ User: []Template{
{ {
Name: "include.tpl", Name: "config.tpl",
Type: Snippet, Type: Snippet,
TemplatePath: "templates/ssh/include.tpl", TemplatePath: "templates/ssh/config.tpl",
Path: "~/.ssh/config", Path: "~/.ssh/config",
Comment: "#", Comment: "#",
}, },
{ {
Name: "config.tpl", Name: "step_includes.tpl",
Type: PrependLine,
TemplatePath: "templates/ssh/step_includes.tpl",
Path: "${STEPPATH}/ssh/includes",
Comment: "#",
},
{
Name: "step_config.tpl",
Type: File, Type: File,
TemplatePath: "templates/ssh/config.tpl", TemplatePath: "templates/ssh/step_config.tpl",
Path: "ssh/config", Path: "ssh/config",
Comment: "#", Comment: "#",
}, },
@ -64,30 +75,43 @@ var DefaultSSHTemplates = SSHTemplates{
// DefaultSSHTemplateData contains the data of the default templates used on ssh. // DefaultSSHTemplateData contains the data of the default templates used on ssh.
var DefaultSSHTemplateData = map[string]string{ var DefaultSSHTemplateData = map[string]string{
// include.tpl adds the step ssh config file. // config.tpl adds the step ssh config file.
// //
// Note: on windows `Include C:\...` is treated as a relative path. // Note: on windows `Include C:\...` is treated as a relative path.
"include.tpl": `Host * "config.tpl": `Host *
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config" {{- if .User.StepBasePath }}
Include "{{ .User.StepBasePath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- else }} {{- else }}
Include "{{.User.StepPath}}/ssh/config" Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- end }}
{{- else }}
{{- if .User.StepBasePath }}
Include "{{.User.StepBasePath}}/ssh/includes"
{{- else }}
Include "{{.User.StepPath}}/ssh/includes"
{{- end }}
{{- end }}`, {{- end }}`,
// config.tpl is the step ssh config file, it includes the Match rule and // step_includes.tpl adds the step ssh config file.
//
// Note: on windows `Include C:\...` is treated as a relative path.
"step_includes.tpl": `{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}`,
// step_config.tpl is the step ssh config file, it includes the Match rule and
// references the step known_hosts file. // references the step known_hosts file.
// //
// Note: on windows ProxyCommand requires the full path // Note: on windows ProxyCommand requires the full path
"config.tpl": `Match exec "step ssh check-host %h" "step_config.tpl": `Match exec "step ssh check-host{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %h"
{{- if .User.User }} {{- if .User.User }}
User {{.User.User}} User {{.User.User}}
{{- end }} {{- end }}
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts" UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts"
ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand %r %h %p ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- else }} {{- else }}
UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts" UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts"
ProxyCommand step ssh proxycommand %r %h %p ProxyCommand step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- end }} {{- end }}
`, `,