Merge branch 'master' into hs/acme-revocation

This commit is contained in:
Herman Slatman 2021-11-19 17:00:18 +01:00
commit 2d50c96d99
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
62 changed files with 643 additions and 459 deletions

View file

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.15', '1.16', '1.17' ] go: [ '1.16', '1.17' ]
steps: steps:
- -
name: Checkout name: Checkout
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.41.0' version: 'v1.43.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -58,6 +58,7 @@ jobs:
run: V=1 make ci run: V=1 make ci
- -
name: Codecov name: Codecov
if: matrix.go == '1.17'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v1.2.1
with: with:
file: ./coverage.out # optional file: ./coverage.out # optional

View file

@ -73,9 +73,3 @@ issues:
- error strings should not be capitalized or end with punctuation or a newline - error strings should not be capitalized or end with punctuation or a newline
- Wrapf call needs 1 arg but has 2 args - Wrapf call needs 1 arg but has 2 args
- cs.NegotiatedProtocolIsMutual is deprecated - cs.NegotiatedProtocolIsMutual is deprecated
# golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration
service:
golangci-lint-version: 1.19.x # use the fixed version to not introduce new linters unexpectedly
prepare:
- echo "here I can run custom commands, but no preparation needed for this repo"

View file

@ -4,15 +4,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased - 0.17.7] - DATE ## [Unreleased - 0.18.1] - DATE
### Added ### Added
- Support for generate extractable keys and certificates on a pkcs#11 module.
### Changed ### Changed
### Deprecated ### Deprecated
### Removed ### Removed
### Fixed ### Fixed
### Security ### Security
## [0.18.0] - 2021-11-17
### Added
- Support for multiple certificate authority contexts.
- Support for generating extractable keys and certificates on a pkcs#11 module.
### Changed
- Support two latest versions of golang (1.16, 1.17)
### Deprecated
- go 1.15 support
## [0.17.6] - 2021-10-20 ## [0.17.6] - 2021-10-20
### Notes ### Notes
- 0.17.5 failed in CI/CD - 0.17.5 failed in CI/CD

View file

@ -5,7 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@ -263,7 +263,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -468,7 +468,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -668,7 +668,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -17,7 +17,7 @@ import (
) )
func link(url, typ string) string { func link(url, typ string) string {
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) return fmt.Sprintf("<%s>;rel=%q", url, typ)
} }
// Clock that returns time in UTC rounded to seconds. // Clock that returns time in UTC rounded to seconds.

View file

@ -7,7 +7,7 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -89,7 +89,7 @@ func TestHandler_GetDirectory(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -261,7 +261,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -404,7 +404,7 @@ func TestHandler_GetCertificate(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -660,7 +660,7 @@ func TestHandler_GetChallenge(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"crypto/rsa" "crypto/rsa"
"errors" "errors"
"io/ioutil" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -118,7 +118,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct. // parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
func (h *Handler) parseJWS(next nextHTTP) nextHTTP { func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
return return
@ -413,7 +413,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
} }
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
value: payload, value: payload,
isPostAsGet: string(payload) == "", isPostAsGet: len(payload) == 0,
isEmptyJSON: string(payload) == "{}", isEmptyJSON: string(payload) == "{}",
}) })
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))

View file

@ -8,7 +8,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -148,7 +147,7 @@ func TestHandler_addNonce(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -205,7 +204,7 @@ func TestHandler_addDirLink(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -332,7 +331,7 @@ func TestHandler_verifyContentType(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -400,7 +399,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -490,7 +489,7 @@ func TestHandler_parseJWS(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -689,7 +688,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -891,7 +890,7 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1087,7 +1086,7 @@ func TestHandler_extractJWK(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1454,7 +1453,7 @@ func TestHandler_validateJWS(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -7,7 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
@ -430,7 +430,7 @@ func TestHandler_GetOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1343,7 +1343,7 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1633,7 +1633,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -12,7 +12,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -92,7 +92,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
"error doing http GET for url %s with status code %d", u, resp.StatusCode)) "error doing http GET for url %s with status code %d", u, resp.StatusCode))
} }
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return WrapErrorISE(err, "error reading "+ return WrapErrorISE(err, "error reading "+
"response body for url %s", u) "response body for url %s", u)

View file

@ -15,7 +15,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math/big" "math/big"
"net" "net"
"net/http" "net/http"
@ -707,7 +706,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -733,7 +732,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -775,7 +774,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -818,7 +817,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
}, },
}, },
@ -860,7 +859,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
}, },
}, },

View file

@ -16,7 +16,7 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"math/big" "math/big"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -788,7 +788,7 @@ func Test_caHandler_Health(t *testing.T) {
t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode) t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Health unexpected error = %v", err) t.Errorf("caHandler.Health unexpected error = %v", err)
@ -829,7 +829,7 @@ func Test_caHandler_Root(t *testing.T) {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err) t.Errorf("caHandler.Root unexpected error = %v", err)
@ -902,7 +902,7 @@ func Test_caHandler_Sign(t *testing.T) {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err) t.Errorf("caHandler.Root unexpected error = %v", err)
@ -954,7 +954,7 @@ func Test_caHandler_Renew(t *testing.T) {
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Renew unexpected error = %v", err) t.Errorf("caHandler.Renew unexpected error = %v", err)
@ -1015,7 +1015,7 @@ func Test_caHandler_Rekey(t *testing.T) {
t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Rekey unexpected error = %v", err) t.Errorf("caHandler.Rekey unexpected error = %v", err)
@ -1038,12 +1038,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
r *http.Request r *http.Request
} }
req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", nil) req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", http.NoBody)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", nil) reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", http.NoBody)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1105,7 +1105,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Provisioners unexpected error = %v", err) t.Errorf("caHandler.Provisioners unexpected error = %v", err)
@ -1175,7 +1175,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Provisioners unexpected error = %v", err) t.Errorf("caHandler.Provisioners unexpected error = %v", err)
@ -1225,7 +1225,7 @@ func Test_caHandler_Roots(t *testing.T) {
t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Roots unexpected error = %v", err) t.Errorf("caHandler.Roots unexpected error = %v", err)
@ -1271,7 +1271,7 @@ func Test_caHandler_Federation(t *testing.T) {
t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Federation unexpected error = %v", err) t.Errorf("caHandler.Federation unexpected error = %v", err)

View file

@ -6,7 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -233,7 +233,7 @@ func Test_caHandler_Revoke(t *testing.T) {
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -10,7 +10,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -299,14 +299,14 @@ func Test_caHandler_SSHSign(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated}, {"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, userB64)), http.StatusCreated},
{"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated}, {"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, hostB64)), http.StatusCreated},
{"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated}, {"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q,"addUserCrt":%q}`, userB64, userB64)), http.StatusCreated},
{"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":"%s","identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated}, {"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":%q,"identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated},
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":%q,"ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized}, {"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized},
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden}, {"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden},
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden}, {"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
@ -338,7 +338,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SignSSH unexpected error = %v", err) t.Errorf("caHandler.SignSSH unexpected error = %v", err)
@ -368,10 +368,10 @@ func Test_caHandler_SSHRoots(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK}, {"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
} }
@ -392,7 +392,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHRoots unexpected error = %v", err) t.Errorf("caHandler.SSHRoots unexpected error = %v", err)
@ -422,10 +422,10 @@ func Test_caHandler_SSHFederation(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK}, {"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
} }
@ -446,7 +446,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHFederation unexpected error = %v", err) t.Errorf("caHandler.SSHFederation unexpected error = %v", err)
@ -506,7 +506,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHConfig unexpected error = %v", err) t.Errorf("caHandler.SSHConfig unexpected error = %v", err)
@ -553,7 +553,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err) t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err)
@ -604,7 +604,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err) t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err)
@ -659,7 +659,7 @@ func Test_caHandler_SSHBastion(t *testing.T) {
t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHBastion unexpected error = %v", err) t.Errorf("caHandler.SSHBastion unexpected error = %v", err)

View file

@ -3,7 +3,6 @@ package api
import ( import (
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil"
"log" "log"
"net/http" "net/http"
@ -102,7 +101,7 @@ func ReadJSON(r io.Reader, v interface{}) error {
// ReadProtoJSON reads JSON from the request body and stores it in the value // ReadProtoJSON reads JSON from the request body and stores it in the value
// pointed by v. // pointed by v.
func ReadProtoJSON(r io.Reader, m proto.Message) error { func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := ioutil.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return errs.Wrap(http.StatusBadRequest, err, "error reading request body") return errs.Wrap(http.StatusBadRequest, err, "error reading request body")
} }

View file

@ -7,8 +7,8 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -195,7 +195,7 @@ func TestAuthority_GetDatabase(t *testing.T) {
} }
func TestNewEmbedded(t *testing.T) { func TestNewEmbedded(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")
@ -268,7 +268,7 @@ func TestNewEmbedded(t *testing.T) {
} }
func TestNewEmbedded_Sign(t *testing.T) { func TestNewEmbedded_Sign(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")
@ -294,7 +294,7 @@ func TestNewEmbedded_Sign(t *testing.T) {
} }
func TestNewEmbedded_GetTLSCertificate(t *testing.T) { func TestNewEmbedded_GetTLSCertificate(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")

View file

@ -2,14 +2,14 @@ package authority
import ( import (
"encoding/json" "encoding/json"
"io/ioutil"
"net/url" "net/url"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
) )
@ -245,7 +245,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
return "" return ""
} }
stepPath := filepath.ToSlash(config.StepPath()) stepPath := filepath.ToSlash(step.Path())
if !strings.HasSuffix(stepPath, "/") { if !strings.HasSuffix(stepPath, "/") {
stepPath += "/" stepPath += "/"
} }
@ -257,7 +257,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
panic(err) panic(err)
} }
if ok { if ok {
b, err := ioutil.ReadFile(config.StepAbs(fn)) b, err := os.ReadFile(step.Abs(fn))
if err != nil { if err != nil {
panic(errors.Wrapf(err, "error reading %s", fn)) panic(errors.Wrapf(err, "error reading %s", fn))
} }

View file

@ -9,9 +9,10 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"os"
"strings" "strings"
"time" "time"
@ -165,7 +166,7 @@ func newAWSConfig(certPath string) (*awsConfig, error) {
if certPath == "" { if certPath == "" {
certBytes = []byte(awsCertificate) certBytes = []byte(awsCertificate)
} else { } else {
if b, err := ioutil.ReadFile(certPath); err == nil { if b, err := os.ReadFile(certPath); err == nil {
certBytes = b certBytes = b
} else { } else {
return nil, errors.Wrapf(err, "error reading %s", certPath) return nil, errors.Wrapf(err, "error reading %s", certPath)
@ -569,7 +570,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
client := http.Client{} client := http.Client{}
// first get the token // first get the token
req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, nil) req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, http.NoBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -582,7 +583,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode) return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
} }
token, err := ioutil.ReadAll(resp.Body) token, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -602,7 +603,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) { func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) {
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,7 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"io/ioutil" "io"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
@ -173,7 +173,7 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error reading identity token response") return "", errors.Wrap(err, "error reading identity token response")
} }

View file

@ -107,7 +107,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
w.Write([]byte(t1)) w.Write([]byte(t1))
default: default:
w.Header().Add("Content-Type", "application/json") w.Header().Add("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{"access_token":"%s"}`, t1))) fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
} }
})) }))
defer srv.Close() defer srv.Close()

View file

@ -7,7 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -183,7 +183,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
return "", errors.Wrap(err, "error doing identity request, are you in a GCP VM?") return "", errors.Wrap(err, "error doing identity request, are you in a GCP VM?")
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error on identity request") return "", errors.Wrap(err, "error on identity request")
} }

View file

@ -8,12 +8,15 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
) )
@ -369,8 +372,19 @@ func newValidityValidator(min, max time.Duration) *validityValidator {
return &validityValidator{min: min, max: max} return &validityValidator{min: min, max: max}
} }
// TODO(mariano): refactor errs package to allow sending real errors to the
// user.
func badRequest(format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return &errs.Error{
Status: http.StatusBadRequest,
Msg: msg,
Err: errors.New(msg),
}
}
// Valid validates the certificate validity settings (notBefore/notAfter) and // Valid validates the certificate validity settings (notBefore/notAfter) and
// and total duration. // total duration.
func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
var ( var (
na = cert.NotAfter.Truncate(time.Second) na = cert.NotAfter.Truncate(time.Second)
@ -381,22 +395,20 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
d := na.Sub(nb) d := na.Sub(nb)
if na.Before(now) { if na.Before(now) {
return errors.Errorf("notAfter cannot be in the past; na=%v", na) return badRequest("notAfter cannot be in the past; na=%v", na)
} }
if na.Before(nb) { if na.Before(nb) {
return errors.Errorf("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) return badRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
} }
if d < v.min { if d < v.min {
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v", return badRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
d, v.min)
} }
// NOTE: this check is not "technically correct". We're allowing the max // NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will // duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not // be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough. // apply a backdate). This is good enough.
if d > v.max+o.Backdate { if d > v.max+o.Backdate {
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v", return badRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
d, v.max+o.Backdate)
} }
return nil return nil
} }

View file

@ -335,11 +335,11 @@ type sshCertValidityValidator struct {
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error {
switch { switch {
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0") return badRequest("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()): case cert.ValidBefore < uint64(now().Unix()):
return errors.New("ssh certificate validBefore cannot be in the past") return badRequest("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter: case cert.ValidBefore < cert.ValidAfter:
return errors.New("ssh certificate validBefore cannot be before validAfter") return badRequest("ssh certificate validBefore cannot be before validAfter")
} }
var min, max time.Duration var min, max time.Duration
@ -351,9 +351,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
min = v.MinHostSSHCertDuration() min = v.MinHostSSHCertDuration()
max = v.MaxHostSSHCertDuration() max = v.MaxHostSSHCertDuration()
case 0: case 0:
return errors.New("ssh certificate type has not been set") return badRequest("ssh certificate type has not been set")
default: default:
return errors.Errorf("unknown ssh certificate type %d", cert.CertType) return badRequest("unknown ssh certificate type %d", cert.CertType)
} }
// To not take into account the backdate, time.Now() will be used to // To not take into account the backdate, time.Now() will be used to
@ -362,11 +362,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
switch { switch {
case dur < min: case dur < min:
return errors.Errorf("requested duration of %s is less than minimum "+ return badRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
"accepted duration for selected provisioner of %s", dur, min)
case dur > max+opts.Backdate: case dur > max+opts.Backdate:
return errors.Errorf("requested duration of %s is greater than maximum "+ return badRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
"accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default: default:
return nil return nil
} }

View file

@ -10,9 +10,9 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"strings" "strings"
"time" "time"
@ -188,7 +188,7 @@ func generateJWK() (*JWK, error) {
} }
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
fooPubB, err := ioutil.ReadFile("./testdata/certs/foo.pub") fooPubB, err := os.ReadFile("./testdata/certs/foo.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -196,7 +196,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
barPubB, err := ioutil.ReadFile("./testdata/certs/bar.pub") barPubB, err := os.ReadFile("./testdata/certs/bar.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -234,7 +234,7 @@ func generateSSHPOP() (*SSHPOP, error) {
return nil, err return nil, err
} }
userB, err := ioutil.ReadFile("./testdata/certs/ssh_user_ca_key.pub") userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,7 +242,7 @@ func generateSSHPOP() (*SSHPOP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
hostB, err := ioutil.ReadFile("./testdata/certs/ssh_host_ca_key.pub") hostB, err := os.ReadFile("./testdata/certs/ssh_host_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,14 +6,14 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
step "go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/linkedca" "go.step.sm/linkedca"
@ -238,6 +238,8 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
return nil return nil
} }
// CreateFirstProvisioner creates and stores the first provisioner when using
// admin database provisioner storage.
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) { func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
if password == "" { if password == "" {
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one") pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
@ -287,6 +289,7 @@ func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (
return p, nil return p, nil
} }
// ValidateClaims validates the Claims type.
func ValidateClaims(c *linkedca.Claims) error { func ValidateClaims(c *linkedca.Claims) error {
if c == nil { if c == nil {
return nil return nil
@ -313,6 +316,7 @@ func ValidateClaims(c *linkedca.Claims) error {
return nil return nil
} }
// ValidateDurations validates the Durations type.
func ValidateDurations(d *linkedca.Durations) error { func ValidateDurations(d *linkedca.Durations) error {
var ( var (
err error err error
@ -523,8 +527,8 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.X509.Template != "" { if p.X509.Template != "" {
x509Template.Template = []byte(p.SSH.Template) x509Template.Template = []byte(p.SSH.Template)
} else if p.X509.TemplateFile != "" { } else if p.X509.TemplateFile != "" {
filename := step.StepAbs(p.X509.TemplateFile) filename := step.Abs(p.X509.TemplateFile)
if x509Template.Template, err = ioutil.ReadFile(filename); err != nil { if x509Template.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading x509 template") return nil, nil, errors.Wrap(err, "error reading x509 template")
} }
} }
@ -539,8 +543,8 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.SSH.Template != "" { if p.SSH.Template != "" {
sshTemplate.Template = []byte(p.SSH.Template) sshTemplate.Template = []byte(p.SSH.Template)
} else if p.SSH.TemplateFile != "" { } else if p.SSH.TemplateFile != "" {
filename := step.StepAbs(p.SSH.TemplateFile) filename := step.Abs(p.SSH.TemplateFile)
if sshTemplate.Template, err = ioutil.ReadFile(filename); err != nil { if sshTemplate.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading ssh template") return nil, nil, errors.Wrap(err, "error reading ssh template")
} }
} }

View file

@ -101,6 +101,15 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Backwards compatibility for version of the cli older than v0.18.0.
// Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
// from the cli.
if o.Name == "step_includes.tpl" && data[templates.SSHTemplateVersionKey] == "" {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
}
output = append(output, o) output = append(output, o)
} }
return output, nil return output, nil

View file

@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")}, {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
} }
tmplConfigUserIncludes := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}
userOutputEmptyData := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
}
userOutputWithoutTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
userOutputWithTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
tmplConfigErr := &templates.Templates{ tmplConfigErr := &templates.Templates{
SSH: &templates.SSHTemplates{ SSH: &templates.SSHTemplates{
User: []templates.Template{ User: []templates.Template{
@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},

View file

@ -0,0 +1 @@
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}

View file

@ -310,7 +310,7 @@ func TestAuthority_Sign(t *testing.T) {
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: _signOpts, signOpts: _signOpts,
err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"), err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
code: http.StatusUnauthorized, code: http.StatusBadRequest,
} }
}, },
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest { "fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
@ -538,15 +538,15 @@ ZYtQ9Ot36qc=
if tc.csr.Subject.CommonName == "" { if tc.csr.Subject.CommonName == "" {
assert.Equals(t, leaf.Subject, pkix.Name{}) assert.Equals(t, leaf.Subject, pkix.Name{})
} else { } else {
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: "smallstep test", CommonName: "smallstep test",
})) }.String())
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"}) assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})
} }
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
@ -718,15 +718,15 @@ func TestAuthority_Renew(t *testing.T) {
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: tmplt.CommonName, CommonName: tmplt.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
@ -925,15 +925,15 @@ func TestAuthority_Rekey(t *testing.T) {
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: tmplt.CommonName, CommonName: tmplt.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)

View file

@ -7,7 +7,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"strings" "strings"
@ -292,7 +291,7 @@ func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Cert
return nil, nil, readACMEError(resp.Body) return nil, nil, readACMEError(resp.Body)
} }
defer resp.Body.Close() defer resp.Body.Close()
bodyBytes, err := ioutil.ReadAll(resp.Body) bodyBytes, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "error reading GET certificate response") return nil, nil, errors.Wrap(err, "error reading GET certificate response")
} }
@ -338,7 +337,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
func readACMEError(r io.ReadCloser) error { func readACMEError(r io.ReadCloser) error {
defer r.Close() defer r.Close()
b, err := ioutil.ReadAll(r) b, err := io.ReadAll(r)
if err != nil { if err != nil {
return errors.Wrap(err, "error reading from body") return errors.Wrap(err, "error reading from body")
} }

View file

@ -5,7 +5,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -317,7 +317,7 @@ func TestACMEClient_post(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -455,7 +455,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -575,7 +575,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -695,7 +695,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -815,7 +815,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -936,7 +936,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1061,7 +1061,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1188,7 +1188,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1317,7 +1317,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -197,7 +197,7 @@ func (c *AdminClient) GetAdminsPaginate(opts ...AdminOption) (*adminAPI.GetAdmin
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create GET %s request failed", u) return nil, errors.Wrapf(err, "create GET %s request failed", u)
} }
@ -284,7 +284,7 @@ func (c *AdminClient) RemoveAdmin(id string) error {
if err != nil { if err != nil {
return errors.Wrapf(err, "error generating admin token") return errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("DELETE", u.String(), nil) req, err := http.NewRequest("DELETE", u.String(), http.NoBody)
if err != nil { if err != nil {
return errors.Wrapf(err, "create DELETE %s request failed", u) return errors.Wrapf(err, "create DELETE %s request failed", u)
} }
@ -363,7 +363,7 @@ func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provi
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create PUT %s request failed", u) return nil, errors.Wrapf(err, "create PUT %s request failed", u)
} }
@ -402,7 +402,7 @@ func (c *AdminClient) GetProvisionersPaginate(opts ...ProvisionerOption) (*admin
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create PUT %s request failed", u) return nil, errors.Wrapf(err, "create PUT %s request failed", u)
} }
@ -472,7 +472,7 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error {
if err != nil { if err != nil {
return errors.Wrapf(err, "error generating admin token") return errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("DELETE", u.String(), nil) req, err := http.NewRequest("DELETE", u.String(), http.NoBody)
if err != nil { if err != nil {
return errors.Wrapf(err, "create DELETE %s request failed", u) return errors.Wrapf(err, "create DELETE %s request failed", u)
} }

View file

@ -3,7 +3,7 @@ package ca
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -382,7 +382,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
return errors.Wrapf(err, "client.Get(%s) failed", srvURL) return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errors.Wrap(err, "client.Get() error reading response") return errors.Wrap(err, "client.Get() error reading response")
} }
@ -499,7 +499,7 @@ func TestBootstrapClientServerFederation(t *testing.T) {
return errors.Wrapf(err, "client.Get(%s) failed", srvURL) return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errors.Wrap(err, "client.Get() error reading response") return errors.Wrap(err, "client.Get() error reading response")
} }
@ -589,9 +589,9 @@ func TestBootstrapListener(t *testing.T) {
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.ReadAll() error = %v", err) t.Errorf("io.ReadAll() error = %v", err)
return return
} }
if string(b) != "ok" { if string(b) != "ok" {

View file

@ -294,15 +294,15 @@ ZEp7knvU2psWRw==
assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second)) assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second)) assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{asn1dn.Country}, Country: []string{asn1dn.Country},
Organization: []string{asn1dn.Organization}, Organization: []string{asn1dn.Organization},
Locality: []string{asn1dn.Locality}, Locality: []string{asn1dn.Locality},
StreetAddress: []string{asn1dn.StreetAddress}, StreetAddress: []string{asn1dn.StreetAddress},
Province: []string{asn1dn.Province}, Province: []string{asn1dn.Province},
CommonName: asn1dn.CommonName, CommonName: asn1dn.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
@ -641,10 +641,10 @@ func TestCARenew(t *testing.T) {
assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second)) assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second)) assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
CommonName: asn1dn.CommonName, CommonName: asn1dn.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)

View file

@ -15,7 +15,6 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -29,7 +28,7 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -75,7 +74,7 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) {
} }
func (c *uaClient) Get(u string) (*http.Response, error) { func (c *uaClient) Get(u string) (*http.Response, error) {
req, err := http.NewRequest("GET", u, nil) req, err := http.NewRequest("GET", u, http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "new request GET %s failed", u) return nil, errors.Wrapf(err, "new request GET %s failed", u)
} }
@ -238,6 +237,17 @@ func WithTransport(tr http.RoundTripper) ClientOption {
} }
} }
// WithInsecure adds a insecure transport that bypasses TLS verification.
func WithInsecure() ClientOption {
return func(o *clientOptions) error {
o.transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return nil
}
}
// WithRootFile will create the transport using the given root certificate. It // WithRootFile will create the transport using the given root certificate. It
// will fail if a previous option to create the transport has been configured. // will fail if a previous option to create the transport has been configured.
func WithRootFile(filename string) ClientOption { func WithRootFile(filename string) ClientOption {
@ -350,7 +360,7 @@ func WithRetryFunc(fn RetryFunc) ClientOption {
} }
func getTransportFromFile(filename string) (http.RoundTripper, error) { func getTransportFromFile(filename string) (http.RoundTripper, error) {
data, err := ioutil.ReadFile(filename) data, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename) return nil, errors.Wrapf(err, "error reading %s", filename)
} }
@ -1295,7 +1305,7 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
// getRootCAPath returns the path where the root CA is stored based on the // getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // STEPPATH environment variable.
func getRootCAPath() string { func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt") return filepath.Join(step.Path(), "certs", "root_ca.crt")
} }
func readJSON(r io.ReadCloser, v interface{}) error { func readJSON(r io.ReadCloser, v interface{}) error {
@ -1305,7 +1315,7 @@ func readJSON(r io.ReadCloser, v interface{}) error {
func readProtoJSON(r io.ReadCloser, m proto.Message) error { func readProtoJSON(r io.ReadCloser, m proto.Message) error {
defer r.Close() defer r.Close()
data, err := ioutil.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return err return err
} }

View file

@ -5,9 +5,9 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -27,21 +27,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json // $STEPPATH/config/identity.json
func LoadClient() (*Client, error) { func LoadClient() (*Client, error) {
b, err := ioutil.ReadFile(DefaultsFile) defaultsFile := DefaultsFile()
b, err := os.ReadFile(defaultsFile)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", DefaultsFile) return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
} }
var defaults defaultsConfig var defaults defaultsConfig
if err := json.Unmarshal(b, &defaults); err != nil { if err := json.Unmarshal(b, &defaults); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile) return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile)
} }
if err := defaults.Validate(); err != nil { if err := defaults.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
caURL, err := url.Parse(defaults.CaURL) caURL, err := url.Parse(defaults.CaURL)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
if caURL.Scheme == "" { if caURL.Scheme == "" {
caURL.Scheme = "https" caURL.Scheme = "https"
@ -52,7 +53,7 @@ func LoadClient() (*Client, error) {
return nil, err return nil, err
} }
if err := identity.Validate(); err != nil { if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", IdentityFile) return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
} }
if kind := identity.Kind(); kind != MutualTLS { if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)
@ -65,7 +66,7 @@ func LoadClient() (*Client, error) {
} }
// RootCAs // RootCAs
b, err = ioutil.ReadFile(defaults.Root) b, err = os.ReadFile(defaults.Root)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error loading %s", defaults.Root) return nil, errors.Wrapf(err, "error loading %s", defaults.Root)
} }

View file

@ -3,14 +3,20 @@ package identity
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os"
"reflect" "reflect"
"testing" "testing"
) )
func returnInput(val string) func() string {
return func() string {
return val
}
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
oldIdentityFile := IdentityFile oldIdentityFile := IdentityFile
oldDefaultsFile := DefaultsFile oldDefaultsFile := DefaultsFile
@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
DefaultsFile = oldDefaultsFile DefaultsFile = oldDefaultsFile
}() }()
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
client, err := LoadClient() client, err := LoadClient()
if err != nil { if err != nil {
@ -40,7 +46,7 @@ func TestClient(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/certs/root_ca.crt") b, err := os.ReadFile("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -114,7 +120,7 @@ func TestLoadClient(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/certs/root_ca.crt") b, err := os.ReadFile("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", func() { {"ok", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, expected, false}, }, expected, false},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/missing.json" IdentityFile = returnInput("testdata/config/missing.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/fail.json" IdentityFile = returnInput("testdata/config/fail.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/missing.json" DefaultsFile = returnInput("testdata/config/missing.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/fail.json" DefaultsFile = returnInput("testdata/config/fail.json")
}, nil, true}, }, nil, true},
{"fail ca", func() { {"fail ca", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badca.json" DefaultsFile = returnInput("testdata/config/badca.json")
}, nil, true}, }, nil, true},
{"fail root", func() { {"fail root", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badroot.json" DefaultsFile = returnInput("testdata/config/badroot.json")
}, nil, true}, }, nil, true},
{"fail type", func() { {"fail type", func() {
IdentityFile = "testdata/config/badIdentity.json" IdentityFile = returnInput("testdata/config/badIdentity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -7,7 +7,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -16,7 +15,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -39,11 +38,18 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims. // DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute const DefaultLeeway = 1 * time.Minute
// IdentityFile contains the location of the identity file. var (
var IdentityFile = filepath.Join(config.StepPath(), "config", "identity.json") identityDir = step.IdentityPath
configDir = step.ConfigPath
// DefaultsFile contains the location of the defaults file. // IdentityFile contains a pointer to a function that outputs the location of
var DefaultsFile = filepath.Join(config.StepPath(), "config", "defaults.json") // the identity file.
IdentityFile = step.IdentityFile
// DefaultsFile contains a prointer a function that outputs the location of the
// defaults configuration file.
DefaultsFile = step.DefaultsFile
)
// Identity represents the identity file that can be used to authenticate with // Identity represents the identity file that can be used to authenticate with
// the CA. // the CA.
@ -61,7 +67,7 @@ type Identity struct {
// LoadIdentity loads an identity present in the given filename. // LoadIdentity loads an identity present in the given filename.
func LoadIdentity(filename string) (*Identity, error) { func LoadIdentity(filename string) (*Identity, error) {
b, err := ioutil.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename) return nil, errors.Wrapf(err, "error reading %s", filename)
} }
@ -74,23 +80,17 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity. // LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) { func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(IdentityFile) return LoadIdentity(IdentityFile())
} }
// configDir and identityDir are used in WriteDefaultIdentity for testing
// purposes.
var (
configDir = filepath.Join(config.StepPath(), "config")
identityDir = filepath.Join(config.StepPath(), "identity")
)
// WriteDefaultIdentity writes the given certificates and key and the // WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files. // identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(configDir, 0700); err != nil { if err := os.MkdirAll(configDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory") return errors.Wrap(err, "error creating config directory")
} }
identityDir := identityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil { if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory") return errors.Wrap(err, "error creating identity directory")
} }
@ -112,7 +112,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
if err := pem.Encode(buf, block); err != nil { if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity key") return errors.Wrap(err, "error encoding identity key")
} }
if err := ioutil.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -127,7 +127,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil { }); err != nil {
return errors.Wrap(err, "error writing identity json") return errors.Wrap(err, "error writing identity json")
} }
if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -136,7 +136,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk. // WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error { func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(identityDir, "identity.crt") filename := filepath.Join(identityDir(), "identity.crt")
return writeCertificate(filename, certChain) return writeCertificate(filename, certChain)
} }
@ -153,7 +153,7 @@ func writeCertificate(filename string, certChain []api.Certificate) error {
} }
} }
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing certificate") return errors.Wrap(err, "error writing certificate")
} }
@ -263,7 +263,7 @@ func (i *Identity) GetCertPool() (*x509.CertPool, error) {
if i.Root == "" { if i.Root == "" {
return nil, nil return nil, nil
} }
b, err := ioutil.ReadFile(i.Root) b, err := os.ReadFile(i.Root)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error reading identity root") return nil, errors.Wrap(err, "error reading identity root")
} }
@ -319,8 +319,8 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate") return errors.Wrap(err, "error encoding identity certificate")
} }
} }
certFilename := filepath.Join(identityDir, "identity.crt") certFilename := filepath.Join(identityDir(), "identity.crt")
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }

View file

@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) {
want *Identity want *Identity
wantErr bool wantErr bool
}{ }{
{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false}, {"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false},
{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true}, {"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true},
{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true}, {"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) {
certChain = append(certChain, api.Certificate{Certificate: c}) certChain = append(certChain, api.Certificate{Certificate: c})
} }
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(tmpDir, "config", "identity.json") IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json"))
type args struct { type args struct {
certChain []api.Certificate certChain []api.Certificate
@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) {
}{ }{
{"ok", func() {}, args{certChain, key}, false}, {"ok", func() {}, args{certChain, key}, false},
{"fail mkdir config", func() { {"fail mkdir config", func() {
configDir = filepath.Join(tmpDir, "identity", "identity.crt") configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail mkdir identity", func() { {"fail mkdir identity", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity", "identity.crt") identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail certificate", func() { {"fail certificate", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail key", func() { {"fail key", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, "badKey"}, true}, }, args{certChain, "badKey"}, true},
{"fail write identity", func() { {"fail write identity", func() {
configDir = filepath.Join(tmpDir, "bad-dir") configDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(configDir, "identity.json") IdentityFile = returnInput(filepath.Join(configDir(), "identity.json"))
os.MkdirAll(configDir, 0600) os.MkdirAll(configDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
} }
@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) {
} }
oldIdentityDir := identityDir oldIdentityDir := identityDir
identityDir = "testdata/identity" identityDir = returnInput("testdata/identity")
defer func() { defer func() {
identityDir = oldIdentityDir identityDir = oldIdentityDir
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) {
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail write identity", func() { {"fail write identity", func() {
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -1,8 +1,8 @@
package ca package ca
import ( import (
"io/ioutil"
"net/url" "net/url"
"os"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -45,7 +45,7 @@ func TestNewProvisioner(t *testing.T) {
defer ca.Close() defer ca.Close()
want := getTestProvisioner(t, ca.URL) want := getTestProvisioner(t, ca.URL)
caBundle, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -18,7 +18,7 @@ var minCertDuration = time.Minute
// TLSRenewer automatically renews a tls certificate using a RenewFunc. // TLSRenewer automatically renews a tls certificate using a RenewFunc.
type TLSRenewer struct { type TLSRenewer struct {
sync.RWMutex renewMutex sync.RWMutex
RenewCertificate RenewFunc RenewCertificate RenewFunc
cert *tls.Certificate cert *tls.Certificate
timer *time.Timer timer *time.Timer
@ -81,9 +81,9 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
func (r *TLSRenewer) Run() { func (r *TLSRenewer) Run() {
cert := r.getCertificate() cert := r.getCertificate()
next := r.nextRenewDuration(cert.Leaf.NotAfter) next := r.nextRenewDuration(cert.Leaf.NotAfter)
r.Lock() r.renewMutex.Lock()
r.timer = time.AfterFunc(next, r.renewCertificate) r.timer = time.AfterFunc(next, r.renewCertificate)
r.Unlock() r.renewMutex.Unlock()
} }
// RunContext starts the certificate renewer for the given certificate. // RunContext starts the certificate renewer for the given certificate.
@ -133,25 +133,25 @@ func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Cer
// if the timer does not fire e.g. when the CA is run from a laptop that // if the timer does not fire e.g. when the CA is run from a laptop that
// enters sleep mode. // enters sleep mode.
func (r *TLSRenewer) getCertificate() *tls.Certificate { func (r *TLSRenewer) getCertificate() *tls.Certificate {
r.RLock() r.renewMutex.RLock()
cert := r.cert cert := r.cert
r.RUnlock() r.renewMutex.RUnlock()
return cert return cert
} }
// getCertificateForCA returns the certificate using a read-only lock. It will // getCertificateForCA returns the certificate using a read-only lock. It will
// automatically renew the certificate if it has expired. // automatically renew the certificate if it has expired.
func (r *TLSRenewer) getCertificateForCA() *tls.Certificate { func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
r.RLock() r.renewMutex.RLock()
// Force certificate renewal if the timer didn't run. // Force certificate renewal if the timer didn't run.
// This is an special case that can happen after a computer sleep. // This is an special case that can happen after a computer sleep.
if time.Now().After(r.certNotAfter) { if time.Now().After(r.certNotAfter) {
r.RUnlock() r.renewMutex.RUnlock()
r.renewCertificate() r.renewCertificate()
r.RLock() r.renewMutex.RLock()
} }
cert := r.cert cert := r.cert
r.RUnlock() r.renewMutex.RUnlock()
return cert return cert
} }
@ -159,10 +159,10 @@ func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
// updates certNotAfter with 1m of delta; this will force the renewal of the // updates certNotAfter with 1m of delta; this will force the renewal of the
// certificate if it is about to expire. // certificate if it is about to expire.
func (r *TLSRenewer) setCertificate(cert *tls.Certificate) { func (r *TLSRenewer) setCertificate(cert *tls.Certificate) {
r.Lock() r.renewMutex.Lock()
r.cert = cert r.cert = cert
r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute) r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute)
r.Unlock() r.renewMutex.Unlock()
} }
func (r *TLSRenewer) renewCertificate() { func (r *TLSRenewer) renewCertificate() {
@ -175,9 +175,9 @@ func (r *TLSRenewer) renewCertificate() {
r.setCertificate(cert) r.setCertificate(cert)
next = r.nextRenewDuration(cert.Leaf.NotAfter) next = r.nextRenewDuration(cert.Leaf.NotAfter)
} }
r.Lock() r.renewMutex.Lock()
r.timer.Reset(next) r.timer.Reset(next)
r.Unlock() r.renewMutex.Unlock()
} }
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration { func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {

View file

@ -4,8 +4,8 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"os"
"reflect" "reflect"
"sort" "sort"
"testing" "testing"
@ -202,7 +202,7 @@ func TestAddRootsToRootCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -256,7 +256,7 @@ func TestAddRootsToClientCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -310,12 +310,12 @@ func TestAddFederationToRootCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -374,12 +374,12 @@ func TestAddFederationToClientCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -438,7 +438,7 @@ func TestAddRootsToCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -492,12 +492,12 @@ func TestAddFederationToCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -8,7 +8,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"io/ioutil" "io"
"log" "log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -221,7 +221,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err) t.Fatalf("ioutil.RealAdd() error = %v", err)
} }
@ -335,7 +335,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err) t.Errorf("ioutil.RealAdd() error = %v", err)
return return
@ -374,9 +374,9 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err) t.Errorf("io.ReadAll() error = %v", err)
return return
} }
if !bytes.Equal(b, []byte("ok")) { if !bytes.Equal(b, []byte("ok")) {

View file

@ -91,7 +91,7 @@ func mustSerializeCrt(filename string, certs ...*x509.Certificate) {
panic(err) panic(err)
} }
} }
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil {
panic(err) panic(err)
} }
} }
@ -105,7 +105,7 @@ func mustSerializeKey(filename string, key crypto.Signer) {
Type: "PRIVATE KEY", Type: "PRIVATE KEY",
Bytes: b, Bytes: b,
}) })
if err := ioutil.WriteFile(filename, b, 0600); err != nil { if err := os.WriteFile(filename, b, 0600); err != nil {
panic(err) panic(err)
} }
} }

View file

@ -21,7 +21,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
"go.step.sm/cli-utils/command" "go.step.sm/cli-utils/command"
"go.step.sm/cli-utils/command/version" "go.step.sm/cli-utils/command/version"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/cli-utils/usage" "go.step.sm/cli-utils/usage"
@ -49,7 +49,7 @@ var (
) )
func init() { func init() {
config.Set("Smallstep CA", Version, BuildTime) step.Set("Smallstep CA", Version, BuildTime)
authority.GlobalVersion.Version = Version authority.GlobalVersion.Version = Version
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }
@ -115,7 +115,7 @@ func main() {
app := cli.NewApp() app := cli.NewApp()
app.Name = "step-ca" app.Name = "step-ca"
app.HelpName = "step-ca" app.HelpName = "step-ca"
app.Version = config.Version() app.Version = step.Version()
app.Usage = "an online certificate authority for secure automated certificate management" app.Usage = "an online certificate authority for secure automated certificate management"
app.UsageText = `**step-ca** <config> [**--password-file**=<file>] app.UsageText = `**step-ca** <config> [**--password-file**=<file>]
[**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>] [**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>]

View file

@ -32,7 +32,7 @@ import (
// Config is a mapping of the cli flags. // Config is a mapping of the cli flags.
type Config struct { type Config struct {
KMS string KMS string
RootOnly bool GenerateRoot bool
RootObject string RootObject string
RootKeyObject string RootKeyObject string
RootSubject string RootSubject string
@ -58,22 +58,29 @@ func (c *Config) Validate() error {
switch { switch {
case c.KMS == "": case c.KMS == "":
return errors.New("flag `--kms` is required") return errors.New("flag `--kms` is required")
case c.CrtPath == "":
return errors.New("flag `--crt-cert-path` is required")
case c.RootFile != "" && c.KeyFile == "": case c.RootFile != "" && c.KeyFile == "":
return errors.New("flag `--root` requires flag `--key`") return errors.New("flag `--root-cert-file` requires flag `--root-key-file`")
case c.KeyFile != "" && c.RootFile == "": case c.KeyFile != "" && c.RootFile == "":
return errors.New("flag `--key` requires flag `--root`") return errors.New("flag `--root-key-file` requires flag `--root-cert-file`")
case c.RootOnly && c.RootFile != "":
return errors.New("flag `--root-only` is incompatible with flag `--root`")
case c.RootFile == "" && c.RootObject == "": case c.RootFile == "" && c.RootObject == "":
return errors.New("one of flag `--root` or `--root-cert` is required") return errors.New("one of flag `--root-cert-file` or `--root-cert-obj` is required")
case c.RootFile == "" && c.RootKeyObject == "": case c.KeyFile == "" && c.RootKeyObject == "":
return errors.New("one of flag `--root` or `--root-key` is required") return errors.New("one of flag `--root-key-file` or `--root-key-obj` is required")
case c.CrtKeyPath == "" && c.CrtKeyObject == "":
return errors.New("one of flag `--crt-key-path` or `--crt-key-obj` is required")
case c.RootFile == "" && c.GenerateRoot && c.RootKeyObject == "":
return errors.New("flag `--root-gen` requires flag `--root-key-obj`")
case c.RootFile == "" && c.GenerateRoot && c.RootPath == "":
return errors.New("flag `--root-gen` requires `--root-cert-path`")
default: default:
if c.RootFile != "" { if c.RootFile != "" {
c.GenerateRoot = false
c.RootObject = "" c.RootObject = ""
c.RootKeyObject = "" c.RootKeyObject = ""
} }
if c.RootOnly { if c.CrtKeyPath != "" {
c.CrtObject = "" c.CrtObject = ""
c.CrtKeyObject = "" c.CrtKeyObject = ""
} }
@ -101,21 +108,27 @@ func main() {
var c Config var c Config
flag.StringVar(&c.KMS, "kms", kmsuri, "PKCS #11 URI with the module-path and token to connect to the module.") flag.StringVar(&c.KMS, "kms", kmsuri, "PKCS #11 URI with the module-path and token to connect to the module.")
flag.StringVar(&c.Pin, "pin", "", "PKCS #11 PIN") flag.StringVar(&c.Pin, "pin", "", "PKCS #11 PIN")
flag.StringVar(&c.RootObject, "root-cert", "pkcs11:id=7330;object=root-cert", "PKCS #11 URI with object id and label to store the root certificate.") // Option 1: Generate new root
flag.StringVar(&c.RootPath, "root-cert-path", "root_ca.crt", "Location to write the root certificate.") flag.BoolVar(&c.GenerateRoot, "root-gen", true, "Enable the generation of a root key.")
flag.StringVar(&c.RootKeyObject, "root-key", "pkcs11:id=7330;object=root-key", "PKCS #11 URI with object id and label to store the root key.")
flag.StringVar(&c.RootSubject, "root-name", "PKCS #11 Smallstep Root", "Subject and Issuer of the root certificate.") flag.StringVar(&c.RootSubject, "root-name", "PKCS #11 Smallstep Root", "Subject and Issuer of the root certificate.")
flag.StringVar(&c.CrtObject, "crt-cert", "pkcs11:id=7331;object=intermediate-cert", "PKCS #11 URI with object id and label to store the intermediate certificate.") flag.StringVar(&c.RootObject, "root-cert-obj", "pkcs11:id=7330;object=root-cert", "PKCS #11 URI with object id and label to store the root certificate.")
flag.StringVar(&c.CrtPath, "crt-cert-path", "intermediate_ca.crt", "Location to write the intermediate certificate.") flag.StringVar(&c.RootKeyObject, "root-key-obj", "pkcs11:id=7330;object=root-key", "PKCS #11 URI with object id and label to store the root key.")
flag.StringVar(&c.CrtKeyObject, "crt-key", "pkcs11:id=7331;object=intermediate-key", "PKCS #11 URI with object id and label to store the intermediate certificate.") // Option 2: Read root from disk and sign intermediate
flag.StringVar(&c.RootFile, "root-cert-file", "", "Path to the root certificate to use.")
flag.StringVar(&c.KeyFile, "root-key-file", "", "Path to the root key to use.")
// Option 3: Generate certificate signing request
flag.StringVar(&c.CrtSubject, "crt-name", "PKCS #11 Smallstep Intermediate", "Subject of the intermediate certificate.") flag.StringVar(&c.CrtSubject, "crt-name", "PKCS #11 Smallstep Intermediate", "Subject of the intermediate certificate.")
flag.StringVar(&c.CrtKeyPath, "crt-key-path", "intermediate_ca_key", "Location to write the intermediate private key.") flag.StringVar(&c.CrtObject, "crt-cert-obj", "pkcs11:id=7331;object=intermediate-cert", "PKCS #11 URI with object id and label to store the intermediate certificate.")
flag.StringVar(&c.CrtKeyObject, "crt-key-obj", "pkcs11:id=7331;object=intermediate-key", "PKCS #11 URI with object id and label to store the intermediate certificate.")
// SSH certificates
flag.BoolVar(&c.EnableSSH, "ssh", false, "Enable the creation of ssh keys.")
flag.StringVar(&c.SSHHostKeyObject, "ssh-host-key", "pkcs11:id=7332;object=ssh-host-key", "PKCS #11 URI with object id and label to store the key used to sign SSH host certificates.") flag.StringVar(&c.SSHHostKeyObject, "ssh-host-key", "pkcs11:id=7332;object=ssh-host-key", "PKCS #11 URI with object id and label to store the key used to sign SSH host certificates.")
flag.StringVar(&c.SSHUserKeyObject, "ssh-user-key", "pkcs11:id=7333;object=ssh-user-key", "PKCS #11 URI with object id and label to store the key used to sign SSH user certificates.") flag.StringVar(&c.SSHUserKeyObject, "ssh-user-key", "pkcs11:id=7333;object=ssh-user-key", "PKCS #11 URI with object id and label to store the key used to sign SSH user certificates.")
flag.BoolVar(&c.RootOnly, "root-only", false, "Store only only the root certificate and sign and intermediate.") // Output files
flag.StringVar(&c.RootFile, "root", "", "Path to the root certificate to use.") flag.StringVar(&c.RootPath, "root-cert-path", "root_ca.crt", "Location to write the root certificate.")
flag.StringVar(&c.KeyFile, "key", "", "Path to the root key to use.") flag.StringVar(&c.CrtPath, "crt-cert-path", "intermediate_ca.crt", "Location to write the intermediate certificate.")
flag.BoolVar(&c.EnableSSH, "ssh", false, "Enable the creation of ssh keys.") flag.StringVar(&c.CrtKeyPath, "crt-key-path", "", "Location to write the intermediate private key.")
// Others
flag.BoolVar(&c.NoCerts, "no-certs", false, "Do not store certificates in the module.") flag.BoolVar(&c.NoCerts, "no-certs", false, "Do not store certificates in the module.")
flag.BoolVar(&c.Force, "force", false, "Force the delete of previous keys.") flag.BoolVar(&c.Force, "force", false, "Force the delete of previous keys.")
flag.BoolVar(&c.Extractable, "extractable", false, "Allow export of private keys under wrap.") flag.BoolVar(&c.Extractable, "extractable", false, "Allow export of private keys under wrap.")
@ -276,22 +289,8 @@ func createPKI(k kms.KeyManager, c Config) error {
// Root Certificate // Root Certificate
var signer crypto.Signer var signer crypto.Signer
var root *x509.Certificate var root *x509.Certificate
if c.RootFile != "" && c.KeyFile != "" { switch {
root, err = pemutil.ReadCertificate(c.RootFile) case c.GenerateRoot:
if err != nil {
return err
}
key, err := pemutil.Read(c.KeyFile)
if err != nil {
return err
}
var ok bool
if signer, ok = key.(crypto.Signer); !ok {
return errors.Errorf("key type '%T' does not implement a signer", key)
}
} else {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: c.RootKeyObject, Name: c.RootKeyObject,
SignatureAlgorithm: apiv1.ECDSAWithSHA256, SignatureAlgorithm: apiv1.ECDSAWithSHA256,
@ -331,7 +330,7 @@ func createPKI(k kms.KeyManager, c Config) error {
return errors.Wrap(err, "error parsing root certificate") return errors.Wrap(err, "error parsing root certificate")
} }
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts { if cm, ok := k.(kms.CertificateManager); ok && c.RootObject != "" && !c.NoCerts {
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{ if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.RootObject, Name: c.RootObject,
Certificate: root, Certificate: root,
@ -339,6 +338,8 @@ func createPKI(k kms.KeyManager, c Config) error {
}); err != nil { }); err != nil {
return err return err
} }
} else {
c.RootObject = ""
} }
if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{
@ -350,12 +351,31 @@ func createPKI(k kms.KeyManager, c Config) error {
ui.PrintSelected("Root Key", resp.Name) ui.PrintSelected("Root Key", resp.Name)
ui.PrintSelected("Root Certificate", c.RootPath) ui.PrintSelected("Root Certificate", c.RootPath)
if c.RootObject != "" {
ui.PrintSelected("Root Certificate Object", c.RootObject)
}
case c.RootFile != "" && c.KeyFile != "": // Read Root From File
root, err = pemutil.ReadCertificate(c.RootFile)
if err != nil {
return err
}
key, err := pemutil.Read(c.KeyFile)
if err != nil {
return err
}
var ok bool
if signer, ok = key.(crypto.Signer); !ok {
return errors.Errorf("key type '%T' does not implement a signer", key)
}
} }
// Intermediate Certificate // Intermediate Certificate
var keyName string var keyName string
var publicKey crypto.PublicKey var publicKey crypto.PublicKey
if c.RootOnly { var intSigner crypto.Signer
if c.CrtKeyPath != "" {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return errors.Wrap(err, "error creating intermediate key") return errors.Wrap(err, "error creating intermediate key")
@ -373,6 +393,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
publicKey = priv.Public() publicKey = priv.Public()
intSigner = priv
} else { } else {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: c.CrtKeyObject, Name: c.CrtKeyObject,
@ -384,8 +405,14 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
publicKey = resp.PublicKey publicKey = resp.PublicKey
keyName = resp.Name keyName = resp.Name
intSigner, err = k.CreateSigner(&resp.CreateSignerRequest)
if err != nil {
return err
}
} }
if root != nil {
template := &x509.Certificate{ template := &x509.Certificate{
IsCA: true, IsCA: true,
NotBefore: now, NotBefore: now,
@ -410,7 +437,7 @@ func createPKI(k kms.KeyManager, c Config) error {
return errors.Wrap(err, "error parsing intermediate certificate") return errors.Wrap(err, "error parsing intermediate certificate")
} }
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts { if cm, ok := k.(kms.CertificateManager); ok && c.CrtObject != "" && !c.NoCerts {
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{ if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.CrtObject, Name: c.CrtObject,
Certificate: intermediate, Certificate: intermediate,
@ -418,6 +445,8 @@ func createPKI(k kms.KeyManager, c Config) error {
}); err != nil { }); err != nil {
return err return err
} }
} else {
c.CrtObject = ""
} }
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
@ -426,14 +455,39 @@ func createPKI(k kms.KeyManager, c Config) error {
}), 0600); err != nil { }), 0600); err != nil {
return err return err
} }
} else {
// No root available, generate CSR for external root.
csrTemplate := x509.CertificateRequest{
Subject: pkix.Name{CommonName: c.CrtSubject},
SignatureAlgorithm: x509.ECDSAWithSHA256,
}
// step: generate the csr request
csrCertificate, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, intSigner)
if err != nil {
return err
}
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrCertificate,
}), 0600); err != nil {
return err
}
}
if c.RootOnly { if c.CrtKeyPath != "" {
ui.PrintSelected("Intermediate Key", c.CrtKeyPath) ui.PrintSelected("Intermediate Key", c.CrtKeyPath)
} else { } else {
ui.PrintSelected("Intermediate Key", keyName) ui.PrintSelected("Intermediate Key", keyName)
} }
if root != nil {
ui.PrintSelected("Intermediate Certificate", c.CrtPath) ui.PrintSelected("Intermediate Certificate", c.CrtPath)
if c.CrtObject != "" {
ui.PrintSelected("Intermediate Certificate Object", c.CrtObject)
}
} else {
ui.PrintSelected("Intermediate Certificate Request", c.CrtPath)
}
if c.SSHHostKeyObject != "" { if c.SSHHostKeyObject != "" {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -98,7 +97,7 @@ To get a linked authority token:
var password []byte var password []byte
if passFile != "" { if passFile != "" {
if password, err = ioutil.ReadFile(passFile); err != nil { if password, err = os.ReadFile(passFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", passFile)) fatal(errors.Wrapf(err, "error reading %s", passFile))
} }
password = bytes.TrimRightFunc(password, unicode.IsSpace) password = bytes.TrimRightFunc(password, unicode.IsSpace)
@ -106,7 +105,7 @@ To get a linked authority token:
var sshHostPassword []byte var sshHostPassword []byte
if sshHostPassFile != "" { if sshHostPassFile != "" {
if sshHostPassword, err = ioutil.ReadFile(sshHostPassFile); err != nil { if sshHostPassword, err = os.ReadFile(sshHostPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", sshHostPassFile)) fatal(errors.Wrapf(err, "error reading %s", sshHostPassFile))
} }
sshHostPassword = bytes.TrimRightFunc(sshHostPassword, unicode.IsSpace) sshHostPassword = bytes.TrimRightFunc(sshHostPassword, unicode.IsSpace)
@ -114,7 +113,7 @@ To get a linked authority token:
var sshUserPassword []byte var sshUserPassword []byte
if sshUserPassFile != "" { if sshUserPassFile != "" {
if sshUserPassword, err = ioutil.ReadFile(sshUserPassFile); err != nil { if sshUserPassword, err = os.ReadFile(sshUserPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", sshUserPassFile)) fatal(errors.Wrapf(err, "error reading %s", sshUserPassFile))
} }
sshUserPassword = bytes.TrimRightFunc(sshUserPassword, unicode.IsSpace) sshUserPassword = bytes.TrimRightFunc(sshUserPassword, unicode.IsSpace)
@ -122,7 +121,7 @@ To get a linked authority token:
var issuerPassword []byte var issuerPassword []byte
if issuerPassFile != "" { if issuerPassFile != "" {
if issuerPassword, err = ioutil.ReadFile(issuerPassFile); err != nil { if issuerPassword, err = os.ReadFile(issuerPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", issuerPassFile)) fatal(errors.Wrapf(err, "error reading %s", issuerPassFile))
} }
issuerPassword = bytes.TrimRightFunc(issuerPassword, unicode.IsSpace) issuerPassword = bytes.TrimRightFunc(issuerPassword, unicode.IsSpace)

View file

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "os"
"unicode" "unicode"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -72,14 +72,14 @@ func exportAction(ctx *cli.Context) error {
} }
if passwordFile != "" { if passwordFile != "" {
b, err := ioutil.ReadFile(passwordFile) b, err := os.ReadFile(passwordFile)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", passwordFile) return errors.Wrapf(err, "error reading %s", passwordFile)
} }
cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
} }
if issuerPasswordFile != "" { if issuerPasswordFile != "" {
b, err := ioutil.ReadFile(issuerPasswordFile) b, err := os.ReadFile(issuerPasswordFile)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", issuerPasswordFile) return errors.Wrapf(err, "error reading %s", issuerPasswordFile)
} }

View file

@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"time" "time"
@ -32,7 +32,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"time" "time"
@ -32,7 +32,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
panic(err) panic(err)

15
go.mod
View file

@ -1,6 +1,6 @@
module github.com/smallstep/certificates module github.com/smallstep/certificates
go 1.15 go 1.16
require ( require (
cloud.google.com/go v0.83.0 cloud.google.com/go v0.83.0
@ -29,10 +29,10 @@ require (
github.com/rs/xid v1.2.1 github.com/rs/xid v1.2.1
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262
github.com/smallstep/nosql v0.3.8 github.com/smallstep/nosql v0.3.9
github.com/urfave/cli v1.22.4 github.com/urfave/cli v1.22.4
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.6.2 go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.13.0 go.step.sm/crypto v0.13.0
go.step.sm/linkedca v0.7.0 go.step.sm/linkedca v0.7.0
golang.org/x/crypto v0.0.0-20210915214749-c084706c2272 golang.org/x/crypto v0.0.0-20210915214749-c084706c2272
@ -44,7 +44,8 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/square/go-jose.v2 v2.6.0
) )
// replace github.com/smallstep/nosql => ../nosql //replace github.com/smallstep/nosql => ../nosql
// replace go.step.sm/crypto => ../crypto
// replace go.step.sm/cli-utils => ../cli-utils //replace go.step.sm/crypto => ../crypto
// replace go.step.sm/linkedca => ../linkedca
//replace go.step.sm/cli-utils => ../cli-utils

11
go.sum
View file

@ -367,10 +367,8 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
@ -496,8 +494,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd
github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
github.com/smallstep/nosql v0.3.8 h1:1/EWUbbEdz9ai0g9Fd09VekVjtxp+5+gIHpV2PdwW3o= github.com/smallstep/nosql v0.3.9 h1:YPy5PR3PXClqmpFaVv0wfXDXDc7NXGBE1auyU2c87dc=
github.com/smallstep/nosql v0.3.8/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs= github.com/smallstep/nosql v0.3.9/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
@ -561,8 +559,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.step.sm/cli-utils v0.6.2 h1:ofa3G/EqE3dTDXmzoXHDZr18qJZoFsKSzbzuF+mxuZU= go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds=
go.step.sm/cli-utils v0.6.2/go.mod h1:0tZ8F2QwLgD6KbKj4nrQZhMakTasEAnOcW3Ekc5pnrA= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E=
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8= go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8=
go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg= go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg=
@ -956,7 +954,6 @@ google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"crypto" "crypto"
"fmt" "fmt"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -165,7 +165,7 @@ func TestCloudKMS_Close(t *testing.T) {
func TestCloudKMS_CreateSigner(t *testing.T) { func TestCloudKMS_CreateSigner(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -223,7 +223,7 @@ func TestCloudKMS_CreateKey(t *testing.T) {
testError := fmt.Errorf("an error") testError := fmt.Errorf("an error")
alreadyExists := status.Error(codes.AlreadyExists, "already exists") alreadyExists := status.Error(codes.AlreadyExists, "already exists")
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -389,7 +389,7 @@ func TestCloudKMS_GetPublicKey(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
testError := fmt.Errorf("an error") testError := fmt.Errorf("an error")
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -7,7 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io" "io"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -17,7 +17,7 @@ import (
) )
func Test_newSigner(t *testing.T) { func Test_newSigner(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -70,7 +70,7 @@ func Test_newSigner(t *testing.T) {
} }
func Test_signer_Public(t *testing.T) { func Test_signer_Public(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -159,7 +159,7 @@ func Test_signer_Sign(t *testing.T) {
} }
func TestSigner_SignatureAlgorithm(t *testing.T) { func TestSigner_SignatureAlgorithm(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -11,7 +11,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -78,7 +78,7 @@ func TestSoftKMS_CreateSigner(t *testing.T) {
} }
// Read and decode file using standard packages // Read and decode file using standard packages
b, err := ioutil.ReadFile("testdata/priv.pem") b, err := os.ReadFile("testdata/priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -234,7 +234,7 @@ func TestSoftKMS_CreateKey(t *testing.T) {
} }
func TestSoftKMS_GetPublicKey(t *testing.T) { func TestSoftKMS_GetPublicKey(t *testing.T) {
b, err := ioutil.ReadFile("testdata/pub.pem") b, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -332,7 +332,7 @@ func TestSoftKMS_CreateDecrypter(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/rsa.priv.pem") b, err := os.ReadFile("testdata/rsa.priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -9,7 +9,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"io/ioutil"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@ -202,7 +201,7 @@ func TestNew(t *testing.T) {
}) })
// Load ssh test fixtures // Load ssh test fixtures
b, err := ioutil.ReadFile("testdata/ssh") b, err := os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -290,7 +289,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
} }
// Read and decode file using standard packages // Read and decode file using standard packages
b, err := ioutil.ReadFile("testdata/priv.pem") b, err := os.ReadFile("testdata/priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -315,7 +314,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
}) })
// Load ssh test fixtures // Load ssh test fixtures
sshPubKeyStr, err := ioutil.ReadFile("testdata/ssh.pub") sshPubKeyStr, err := os.ReadFile("testdata/ssh.pub")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -323,7 +322,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err = ioutil.ReadFile("testdata/ssh") b, err = os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -499,7 +498,7 @@ func TestSSHAgentKMS_CreateKey(t *testing.T) {
*/ */
func TestSSHAgentKMS_GetPublicKey(t *testing.T) { func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
b, err := ioutil.ReadFile("testdata/pub.pem") b, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -510,7 +509,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
} }
// Load ssh test fixtures // Load ssh test fixtures
b, err = ioutil.ReadFile("testdata/ssh.pub") b, err = os.ReadFile("testdata/ssh.pub")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -518,7 +517,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err = ioutil.ReadFile("testdata/ssh") b, err = os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -3,8 +3,8 @@ package uri
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"io/ioutil"
"net/url" "net/url"
"os"
"strings" "strings"
"unicode" "unicode"
@ -140,7 +140,7 @@ func readFile(path string) ([]byte, error) {
if err == nil && (u.Scheme == "" || u.Scheme == "file") && u.Path != "" { if err == nil && (u.Scheme == "" || u.Scheme == "file") && u.Path != "" {
path = u.Path path = u.Path
} }
b, err := ioutil.ReadFile(path) b, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", path) return nil, errors.Wrapf(err, "error reading %s", path)
} }

View file

@ -29,9 +29,9 @@ import (
"github.com/smallstep/certificates/kms" "github.com/smallstep/certificates/kms"
kmsapi "github.com/smallstep/certificates/kms/apiv1" kmsapi "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -87,44 +87,50 @@ const (
) )
// GetDBPath returns the path where the file-system persistence is stored // GetDBPath returns the path where the file-system persistence is stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetDBPath() string { func GetDBPath() string {
return filepath.Join(config.StepPath(), dbPath) return filepath.Join(step.Path(), dbPath)
} }
// GetConfigPath returns the directory where the configuration files are stored // GetConfigPath returns the directory where the configuration files are stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetConfigPath() string { func GetConfigPath() string {
return filepath.Join(config.StepPath(), configPath) return filepath.Join(step.Path(), configPath)
}
// GetProfileConfigPath returns the directory where the profile configuration
// files are stored based on the $(step path).
func GetProfileConfigPath() string {
return filepath.Join(step.ProfilePath(), configPath)
} }
// GetPublicPath returns the directory where the public keys are stored based on // GetPublicPath returns the directory where the public keys are stored based on
// the STEPPATH environment variable. // the $(step path).
func GetPublicPath() string { func GetPublicPath() string {
return filepath.Join(config.StepPath(), publicPath) return filepath.Join(step.Path(), publicPath)
} }
// GetSecretsPath returns the directory where the private keys are stored based // GetSecretsPath returns the directory where the private keys are stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetSecretsPath() string { func GetSecretsPath() string {
return filepath.Join(config.StepPath(), privatePath) return filepath.Join(step.Path(), privatePath)
} }
// GetRootCAPath returns the path where the root CA is stored based on the // GetRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // $(step path).
func GetRootCAPath() string { func GetRootCAPath() string {
return filepath.Join(config.StepPath(), publicPath, "root_ca.crt") return filepath.Join(step.Path(), publicPath, "root_ca.crt")
} }
// GetOTTKeyPath returns the path where the one-time token key is stored based // GetOTTKeyPath returns the path where the one-time token key is stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetOTTKeyPath() string { func GetOTTKeyPath() string {
return filepath.Join(config.StepPath(), privatePath, "ott_key") return filepath.Join(step.Path(), privatePath, "ott_key")
} }
// GetTemplatesPath returns the path where the templates are stored. // GetTemplatesPath returns the path where the templates are stored.
func GetTemplatesPath() string { func GetTemplatesPath() string {
return filepath.Join(config.StepPath(), templatesPath) return filepath.Join(step.Path(), templatesPath)
} }
// GetProvisioners returns the map of provisioners on the given CA. // GetProvisioners returns the map of provisioners on the given CA.
@ -293,6 +299,7 @@ type PKI struct {
keyManager kmsapi.KeyManager keyManager kmsapi.KeyManager
config string config string
defaults string defaults string
profileDefaults string
ottPublicKey *jose.JSONWebKey ottPublicKey *jose.JSONWebKey
ottPrivateKey *jose.JSONWebEncryption ottPrivateKey *jose.JSONWebEncryption
options *options options *options
@ -300,6 +307,7 @@ type PKI struct {
// New creates a new PKI configuration. // New creates a new PKI configuration.
func New(o apiv1.Options, opts ...Option) (*PKI, error) { func New(o apiv1.Options, opts ...Option) (*PKI, error) {
currentCtx := step.Contexts().GetCurrent()
caService, err := cas.New(context.Background(), o) caService, err := cas.New(context.Background(), o)
if err != nil { if err != nil {
return nil, err return nil, err
@ -358,6 +366,9 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
cfg = GetConfigPath() cfg = GetConfigPath()
// Create directories // Create directories
dirs := []string{public, private, cfg, GetTemplatesPath()} dirs := []string{public, private, cfg, GetTemplatesPath()}
if currentCtx != nil {
dirs = append(dirs, GetProfileConfigPath())
}
for _, name := range dirs { for _, name := range dirs {
if _, err := os.Stat(name); os.IsNotExist(err) { if _, err := os.Stat(name); os.IsNotExist(err) {
if err = os.MkdirAll(name, 0700); err != nil { if err = os.MkdirAll(name, 0700); err != nil {
@ -415,6 +426,10 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
if p.defaults, err = getPath(cfg, "defaults.json"); err != nil { if p.defaults, err = getPath(cfg, "defaults.json"); err != nil {
return nil, err return nil, err
} }
if currentCtx != nil {
p.profileDefaults = currentCtx.ProfileDefaultsFile()
}
if p.config, err = getPath(cfg, "ca.json"); err != nil { if p.config, err = getPath(cfg, "ca.json"); err != nil {
return nil, err return nil, err
} }
@ -944,6 +959,18 @@ func (p *PKI) Save(opt ...ConfigOption) error {
if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil { if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil {
return errs.FileError(err, p.defaults) return errs.FileError(err, p.defaults)
} }
// If we're using contexts then write a blank object to the default profile
// configuration location.
if p.profileDefaults != "" {
if _, err := os.Stat(p.profileDefaults); os.IsNotExist(err) {
// Write with 0600 to be consistent with directories structure.
if err = fileutil.WriteFile(p.profileDefaults, []byte("{}"), 0600); err != nil {
return errs.FileError(err, p.profileDefaults)
}
} else if err != nil {
return errs.FileError(err, p.profileDefaults)
}
}
// Generate and write templates // Generate and write templates
if err := generateTemplates(cfg.Templates); err != nil { if err := generateTemplates(cfg.Templates); err != nil {
@ -958,6 +985,9 @@ func (p *PKI) Save(opt ...ConfigOption) error {
} }
ui.PrintSelected("Default configuration", p.defaults) ui.PrintSelected("Default configuration", p.defaults)
if p.profileDefaults != "" {
ui.PrintSelected("Default profile configuration", p.profileDefaults)
}
ui.PrintSelected("Certificate Authority configuration", p.config) ui.PrintSelected("Certificate Authority configuration", p.config)
if p.options.deploymentType != LinkedDeployment { if p.options.deploymentType != LinkedDeployment {
ui.Println() ui.Println()

View file

@ -6,9 +6,9 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// getTemplates returns all the templates enabled // getTemplates returns all the templates enabled
@ -44,7 +44,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }
@ -53,7 +53,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }

View file

@ -5,7 +5,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -167,7 +166,7 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation) return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation)
} }
case http.MethodPost: case http.MethodPost:
body, err := ioutil.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize))
if err != nil { if err != nil {
return SCEPRequest{}, err return SCEPRequest{}, err
} }

View file

@ -2,7 +2,6 @@ package templates
import ( import (
"bytes" "bytes"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -10,8 +9,8 @@ import (
"github.com/Masterminds/sprig/v3" "github.com/Masterminds/sprig/v3"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// TemplateType defines how a template will be written in disk. // TemplateType defines how a template will be written in disk.
@ -20,6 +19,9 @@ type TemplateType string
const ( const (
// Snippet will mark a template as a part of a file. // Snippet will mark a template as a part of a file.
Snippet TemplateType = "snippet" Snippet TemplateType = "snippet"
// PrependLine is a template for prepending a single line to a file. If the
// line already exists in the file it will be removed first.
PrependLine TemplateType = "prepend-line"
// File will mark a templates as a full file. // File will mark a templates as a full file.
File TemplateType = "file" File TemplateType = "file"
// Directory will mark a template as a directory. // Directory will mark a template as a directory.
@ -99,7 +101,7 @@ func (t *SSHTemplates) Validate() (err error) {
return return
} }
// Template represents on template file. // Template represents a template file.
type Template struct { type Template struct {
*template.Template *template.Template
Name string `json:"name"` Name string `json:"name"`
@ -118,8 +120,8 @@ func (t *Template) Validate() error {
return nil return nil
case t.Name == "": case t.Name == "":
return errors.New("template name cannot be empty") return errors.New("template name cannot be empty")
case t.Type != Snippet && t.Type != File && t.Type != Directory: case t.Type != Snippet && t.Type != File && t.Type != Directory && t.Type != PrependLine:
return errors.Errorf("invalid template type %s, it must be %s, %s, or %s", t.Type, Snippet, File, Directory) return errors.Errorf("invalid template type %s, it must be %s, %s, %s, or %s", t.Type, Snippet, PrependLine, File, Directory)
case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0: case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0:
return errors.New("template template cannot be empty") return errors.New("template template cannot be empty")
case t.TemplatePath != "" && t.Type == Directory: case t.TemplatePath != "" && t.Type == Directory:
@ -132,7 +134,7 @@ func (t *Template) Validate() error {
if t.TemplatePath != "" { if t.TemplatePath != "" {
// Check for file // Check for file
st, err := os.Stat(config.StepAbs(t.TemplatePath)) st, err := os.Stat(step.Abs(t.TemplatePath))
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", t.TemplatePath) return errors.Wrapf(err, "error reading %s", t.TemplatePath)
} }
@ -166,8 +168,8 @@ func (t *Template) Load() error {
if t.Template == nil && t.Type != Directory { if t.Template == nil && t.Type != Directory {
switch { switch {
case t.TemplatePath != "": case t.TemplatePath != "":
filename := config.StepAbs(t.TemplatePath) filename := step.Abs(t.TemplatePath)
b, err := ioutil.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", filename) return errors.Wrapf(err, "error reading %s", filename)
} }
@ -247,7 +249,10 @@ type Output struct {
// Write writes the Output to the filesystem as a directory, file or snippet. // Write writes the Output to the filesystem as a directory, file or snippet.
func (o *Output) Write() error { func (o *Output) Write() error {
path := config.StepAbs(o.Path) // Replace ${STEPPATH} with the base step path.
o.Path = strings.ReplaceAll(o.Path, "${STEPPATH}", step.BasePath())
path := step.Abs(o.Path)
if o.Type == Directory { if o.Type == Directory {
return mkdir(path, 0700) return mkdir(path, 0700)
} }
@ -257,11 +262,17 @@ func (o *Output) Write() error {
return err return err
} }
if o.Type == File { switch o.Type {
case File:
return fileutil.WriteFile(path, o.Content, 0600) return fileutil.WriteFile(path, o.Content, 0600)
} case Snippet:
return fileutil.WriteSnippet(path, o.Content, 0600) return fileutil.WriteSnippet(path, o.Content, 0600)
case PrependLine:
return fileutil.PrependLine(path, o.Content, 0600)
default:
// Default to using a Snippet type if the type is not known.
return fileutil.WriteSnippet(path, o.Content, 0600)
}
} }
func mkdir(path string, perm os.FileMode) error { func mkdir(path string, perm os.FileMode) error {

View file

@ -4,6 +4,10 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// SSHTemplateVersionKey is a key that can be submitted by a client to select
// the template version that will be returned by the server.
var SSHTemplateVersionKey = "StepSSHTemplateVersion"
// Step represents the default variables available in the CA. // Step represents the default variables available in the CA.
type Step struct { type Step struct {
SSH StepSSH SSH StepSSH
@ -22,16 +26,23 @@ type StepSSH struct {
var DefaultSSHTemplates = SSHTemplates{ var DefaultSSHTemplates = SSHTemplates{
User: []Template{ User: []Template{
{ {
Name: "include.tpl", Name: "config.tpl",
Type: Snippet, Type: Snippet,
TemplatePath: "templates/ssh/include.tpl", TemplatePath: "templates/ssh/config.tpl",
Path: "~/.ssh/config", Path: "~/.ssh/config",
Comment: "#", Comment: "#",
}, },
{ {
Name: "config.tpl", Name: "step_includes.tpl",
Type: PrependLine,
TemplatePath: "templates/ssh/step_includes.tpl",
Path: "${STEPPATH}/ssh/includes",
Comment: "#",
},
{
Name: "step_config.tpl",
Type: File, Type: File,
TemplatePath: "templates/ssh/config.tpl", TemplatePath: "templates/ssh/step_config.tpl",
Path: "ssh/config", Path: "ssh/config",
Comment: "#", Comment: "#",
}, },
@ -64,30 +75,43 @@ var DefaultSSHTemplates = SSHTemplates{
// DefaultSSHTemplateData contains the data of the default templates used on ssh. // DefaultSSHTemplateData contains the data of the default templates used on ssh.
var DefaultSSHTemplateData = map[string]string{ var DefaultSSHTemplateData = map[string]string{
// include.tpl adds the step ssh config file. // config.tpl adds the step ssh config file.
// //
// Note: on windows `Include C:\...` is treated as a relative path. // Note: on windows `Include C:\...` is treated as a relative path.
"include.tpl": `Host * "config.tpl": `Host *
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config" {{- if .User.StepBasePath }}
Include "{{ .User.StepBasePath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- else }} {{- else }}
Include "{{.User.StepPath}}/ssh/config" Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- end }}
{{- else }}
{{- if .User.StepBasePath }}
Include "{{.User.StepBasePath}}/ssh/includes"
{{- else }}
Include "{{.User.StepPath}}/ssh/includes"
{{- end }}
{{- end }}`, {{- end }}`,
// config.tpl is the step ssh config file, it includes the Match rule and // step_includes.tpl adds the step ssh config file.
//
// Note: on windows `Include C:\...` is treated as a relative path.
"step_includes.tpl": `{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}`,
// step_config.tpl is the step ssh config file, it includes the Match rule and
// references the step known_hosts file. // references the step known_hosts file.
// //
// Note: on windows ProxyCommand requires the full path // Note: on windows ProxyCommand requires the full path
"config.tpl": `Match exec "step ssh check-host %h" "step_config.tpl": `Match exec "step ssh check-host{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %h"
{{- if .User.User }} {{- if .User.User }}
User {{.User.User}} User {{.User.User}}
{{- end }} {{- end }}
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts" UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts"
ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand %r %h %p ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- else }} {{- else }}
UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts" UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts"
ProxyCommand step ssh proxycommand %r %h %p ProxyCommand step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- end }} {{- end }}
`, `,