Merge branch 'master' into hs/acme-revocation

This commit is contained in:
Herman Slatman 2021-11-19 17:00:18 +01:00
commit 2d50c96d99
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
62 changed files with 643 additions and 459 deletions

View file

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.15', '1.16', '1.17' ] go: [ '1.16', '1.17' ]
steps: steps:
- -
name: Checkout name: Checkout
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.41.0' version: 'v1.43.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -58,6 +58,7 @@ jobs:
run: V=1 make ci run: V=1 make ci
- -
name: Codecov name: Codecov
if: matrix.go == '1.17'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v1.2.1
with: with:
file: ./coverage.out # optional file: ./coverage.out # optional

View file

@ -73,9 +73,3 @@ issues:
- error strings should not be capitalized or end with punctuation or a newline - error strings should not be capitalized or end with punctuation or a newline
- Wrapf call needs 1 arg but has 2 args - Wrapf call needs 1 arg but has 2 args
- cs.NegotiatedProtocolIsMutual is deprecated - cs.NegotiatedProtocolIsMutual is deprecated
# golangci.com configuration
# https://github.com/golangci/golangci/wiki/Configuration
service:
golangci-lint-version: 1.19.x # use the fixed version to not introduce new linters unexpectedly
prepare:
- echo "here I can run custom commands, but no preparation needed for this repo"

View file

@ -4,15 +4,23 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased - 0.17.7] - DATE ## [Unreleased - 0.18.1] - DATE
### Added ### Added
- Support for generate extractable keys and certificates on a pkcs#11 module.
### Changed ### Changed
### Deprecated ### Deprecated
### Removed ### Removed
### Fixed ### Fixed
### Security ### Security
## [0.18.0] - 2021-11-17
### Added
- Support for multiple certificate authority contexts.
- Support for generating extractable keys and certificates on a pkcs#11 module.
### Changed
- Support two latest versions of golang (1.16, 1.17)
### Deprecated
- go 1.15 support
## [0.17.6] - 2021-10-20 ## [0.17.6] - 2021-10-20
### Notes ### Notes
- 0.17.5 failed in CI/CD - 0.17.5 failed in CI/CD

View file

@ -5,7 +5,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
@ -263,7 +263,7 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -468,7 +468,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -668,7 +668,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -17,7 +17,7 @@ import (
) )
func link(url, typ string) string { func link(url, typ string) string {
return fmt.Sprintf("<%s>;rel=\"%s\"", url, typ) return fmt.Sprintf("<%s>;rel=%q", url, typ)
} }
// Clock that returns time in UTC rounded to seconds. // Clock that returns time in UTC rounded to seconds.

View file

@ -7,7 +7,7 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -89,7 +89,7 @@ func TestHandler_GetDirectory(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -261,7 +261,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -404,7 +404,7 @@ func TestHandler_GetCertificate(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -660,7 +660,7 @@ func TestHandler_GetChallenge(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"crypto/rsa" "crypto/rsa"
"errors" "errors"
"io/ioutil" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -118,7 +118,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct. // parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
func (h *Handler) parseJWS(next nextHTTP) nextHTTP { func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body")) api.WriteError(w, acme.WrapErrorISE(err, "failed to read request body"))
return return
@ -413,7 +413,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
} }
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{ ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{
value: payload, value: payload,
isPostAsGet: string(payload) == "", isPostAsGet: len(payload) == 0,
isEmptyJSON: string(payload) == "{}", isEmptyJSON: string(payload) == "{}",
}) })
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))

View file

@ -8,7 +8,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -148,7 +147,7 @@ func TestHandler_addNonce(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -205,7 +204,7 @@ func TestHandler_addDirLink(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -332,7 +331,7 @@ func TestHandler_verifyContentType(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -400,7 +399,7 @@ func TestHandler_isPostAsGet(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -490,7 +489,7 @@ func TestHandler_parseJWS(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -689,7 +688,7 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -891,7 +890,7 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1087,7 +1086,7 @@ func TestHandler_extractJWK(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1454,7 +1453,7 @@ func TestHandler_validateJWS(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -7,7 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect" "reflect"
@ -430,7 +430,7 @@ func TestHandler_GetOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1343,7 +1343,7 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1633,7 +1633,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -12,7 +12,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -92,7 +92,7 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
"error doing http GET for url %s with status code %d", u, resp.StatusCode)) "error doing http GET for url %s with status code %d", u, resp.StatusCode))
} }
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return WrapErrorISE(err, "error reading "+ return WrapErrorISE(err, "error reading "+
"response body for url %s", u) "response body for url %s", u)

View file

@ -15,7 +15,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math/big" "math/big"
"net" "net"
"net/http" "net/http"
@ -707,7 +706,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -733,7 +732,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -775,7 +774,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
}, },
}, },
@ -818,7 +817,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
}, },
}, },
@ -860,7 +859,7 @@ func TestHTTP01Validate(t *testing.T) {
vo: &ValidateChallengeOptions{ vo: &ValidateChallengeOptions{
HTTPGet: func(url string) (*http.Response, error) { HTTPGet: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: ioutil.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
}, },
}, },

View file

@ -16,7 +16,7 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"math/big" "math/big"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -788,7 +788,7 @@ func Test_caHandler_Health(t *testing.T) {
t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode) t.Errorf("caHandler.Health StatusCode = %d, wants 200", res.StatusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Health unexpected error = %v", err) t.Errorf("caHandler.Health unexpected error = %v", err)
@ -829,7 +829,7 @@ func Test_caHandler_Root(t *testing.T) {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err) t.Errorf("caHandler.Root unexpected error = %v", err)
@ -902,7 +902,7 @@ func Test_caHandler_Sign(t *testing.T) {
t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Root StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Root unexpected error = %v", err) t.Errorf("caHandler.Root unexpected error = %v", err)
@ -954,7 +954,7 @@ func Test_caHandler_Renew(t *testing.T) {
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Renew unexpected error = %v", err) t.Errorf("caHandler.Renew unexpected error = %v", err)
@ -1015,7 +1015,7 @@ func Test_caHandler_Rekey(t *testing.T) {
t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Rekey StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Rekey unexpected error = %v", err) t.Errorf("caHandler.Rekey unexpected error = %v", err)
@ -1038,12 +1038,12 @@ func Test_caHandler_Provisioners(t *testing.T) {
r *http.Request r *http.Request
} }
req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", nil) req, err := http.NewRequest("GET", "http://example.com/provisioners?cursor=foo&limit=20", http.NoBody)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", nil) reqLimitFail, err := http.NewRequest("GET", "http://example.com/provisioners?limit=abc", http.NoBody)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1105,7 +1105,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Provisioners unexpected error = %v", err) t.Errorf("caHandler.Provisioners unexpected error = %v", err)
@ -1175,7 +1175,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Provisioners StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Provisioners unexpected error = %v", err) t.Errorf("caHandler.Provisioners unexpected error = %v", err)
@ -1225,7 +1225,7 @@ func Test_caHandler_Roots(t *testing.T) {
t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Roots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Roots unexpected error = %v", err) t.Errorf("caHandler.Roots unexpected error = %v", err)
@ -1271,7 +1271,7 @@ func Test_caHandler_Federation(t *testing.T) {
t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.Federation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Federation unexpected error = %v", err) t.Errorf("caHandler.Federation unexpected error = %v", err)

View file

@ -6,7 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -233,7 +233,7 @@ func Test_caHandler_Revoke(t *testing.T) {
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -10,7 +10,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -299,14 +299,14 @@ func Test_caHandler_SSHSign(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, userB64)), http.StatusCreated}, {"ok-user", userReq, nil, user, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, userB64)), http.StatusCreated},
{"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s"}`, hostB64)), http.StatusCreated}, {"ok-host", hostReq, nil, host, nil, nil, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q}`, hostB64)), http.StatusCreated},
{"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":"%s","addUserCrt":"%s"}`, userB64, userB64)), http.StatusCreated}, {"ok-user-add", userAddReq, nil, user, nil, user, nil, nil, nil, []byte(fmt.Sprintf(`{"crt":%q,"addUserCrt":%q}`, userB64, userB64)), http.StatusCreated},
{"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":"%s","identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated}, {"ok-user-identity", userIdentityReq, nil, user, nil, user, nil, identityCerts, nil, []byte(fmt.Sprintf(`{"crt":%q,"identityCrt":[%s]}`, userB64, identityCertsPEM)), http.StatusCreated},
{"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-body", []byte("bad-json"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-validate", []byte("{}"), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-publicKey", []byte(`{"publicKey":"Zm9v","ott":"ott"}`), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":"%s","ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest}, {"fail-publicKey", []byte(fmt.Sprintf(`{"publicKey":%q,"ott":"ott","addUserPublicKey":"Zm9v"}`, base64.StdEncoding.EncodeToString(user.Key.Marshal()))), nil, nil, nil, nil, nil, nil, nil, nil, http.StatusBadRequest},
{"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized}, {"fail-authorize", userReq, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, nil, nil, http.StatusUnauthorized},
{"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden}, {"fail-signSSH", userReq, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, nil, nil, http.StatusForbidden},
{"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden}, {"fail-SignSSHAddUser", userAddReq, nil, user, nil, nil, fmt.Errorf("an-error"), nil, nil, nil, http.StatusForbidden},
@ -338,7 +338,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SignSSH StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SignSSH unexpected error = %v", err) t.Errorf("caHandler.SignSSH unexpected error = %v", err)
@ -368,10 +368,10 @@ func Test_caHandler_SSHRoots(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK}, {"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
} }
@ -392,7 +392,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHRoots StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHRoots unexpected error = %v", err) t.Errorf("caHandler.SSHRoots unexpected error = %v", err)
@ -422,10 +422,10 @@ func Test_caHandler_SSHFederation(t *testing.T) {
body []byte body []byte
statusCode int statusCode int
}{ }{
{"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"],"hostKey":["%s"]}`, userB64, hostB64)), http.StatusOK}, {"ok", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}, UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q],"hostKey":[%q]}`, userB64, hostB64)), http.StatusOK},
{"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s","%s"],"hostKey":["%s","%s"]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK}, {"many", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host, host}, UserKeys: []ssh.PublicKey{user, user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q,%q],"hostKey":[%q,%q]}`, userB64, userB64, hostB64, hostB64)), http.StatusOK},
{"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":["%s"]}`, userB64)), http.StatusOK}, {"user", &authority.SSHKeys{UserKeys: []ssh.PublicKey{user}}, nil, []byte(fmt.Sprintf(`{"userKey":[%q]}`, userB64)), http.StatusOK},
{"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":["%s"]}`, hostB64)), http.StatusOK}, {"host", &authority.SSHKeys{HostKeys: []ssh.PublicKey{host}}, nil, []byte(fmt.Sprintf(`{"hostKey":[%q]}`, hostB64)), http.StatusOK},
{"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound}, {"empty", &authority.SSHKeys{}, nil, nil, http.StatusNotFound},
{"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError}, {"error", nil, fmt.Errorf("an error"), nil, http.StatusInternalServerError},
} }
@ -446,7 +446,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHFederation StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHFederation unexpected error = %v", err) t.Errorf("caHandler.SSHFederation unexpected error = %v", err)
@ -506,7 +506,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHConfig StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHConfig unexpected error = %v", err) t.Errorf("caHandler.SSHConfig unexpected error = %v", err)
@ -553,7 +553,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHCheckHost StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err) t.Errorf("caHandler.SSHCheckHost unexpected error = %v", err)
@ -604,7 +604,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHGetHosts StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err) t.Errorf("caHandler.SSHGetHosts unexpected error = %v", err)
@ -659,7 +659,7 @@ func Test_caHandler_SSHBastion(t *testing.T) {
t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) t.Errorf("caHandler.SSHBastion StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close() res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.SSHBastion unexpected error = %v", err) t.Errorf("caHandler.SSHBastion unexpected error = %v", err)

View file

@ -3,7 +3,6 @@ package api
import ( import (
"encoding/json" "encoding/json"
"io" "io"
"io/ioutil"
"log" "log"
"net/http" "net/http"
@ -102,7 +101,7 @@ func ReadJSON(r io.Reader, v interface{}) error {
// ReadProtoJSON reads JSON from the request body and stores it in the value // ReadProtoJSON reads JSON from the request body and stores it in the value
// pointed by v. // pointed by v.
func ReadProtoJSON(r io.Reader, m proto.Message) error { func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := ioutil.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return errs.Wrap(http.StatusBadRequest, err, "error reading request body") return errs.Wrap(http.StatusBadRequest, err, "error reading request body")
} }

View file

@ -7,8 +7,8 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -195,7 +195,7 @@ func TestAuthority_GetDatabase(t *testing.T) {
} }
func TestNewEmbedded(t *testing.T) { func TestNewEmbedded(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")
@ -268,7 +268,7 @@ func TestNewEmbedded(t *testing.T) {
} }
func TestNewEmbedded_Sign(t *testing.T) { func TestNewEmbedded_Sign(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")
@ -294,7 +294,7 @@ func TestNewEmbedded_Sign(t *testing.T) {
} }
func TestNewEmbedded_GetTLSCertificate(t *testing.T) { func TestNewEmbedded_GetTLSCertificate(t *testing.T) {
caPEM, err := ioutil.ReadFile("testdata/certs/root_ca.crt") caPEM, err := os.ReadFile("testdata/certs/root_ca.crt")
assert.FatalError(t, err) assert.FatalError(t, err)
crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt") crt, err := pemutil.ReadCertificate("testdata/certs/intermediate_ca.crt")

View file

@ -2,14 +2,14 @@ package authority
import ( import (
"encoding/json" "encoding/json"
"io/ioutil"
"net/url" "net/url"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
) )
@ -245,7 +245,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
return "" return ""
} }
stepPath := filepath.ToSlash(config.StepPath()) stepPath := filepath.ToSlash(step.Path())
if !strings.HasSuffix(stepPath, "/") { if !strings.HasSuffix(stepPath, "/") {
stepPath += "/" stepPath += "/"
} }
@ -257,7 +257,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
panic(err) panic(err)
} }
if ok { if ok {
b, err := ioutil.ReadFile(config.StepAbs(fn)) b, err := os.ReadFile(step.Abs(fn))
if err != nil { if err != nil {
panic(errors.Wrapf(err, "error reading %s", fn)) panic(errors.Wrapf(err, "error reading %s", fn))
} }

View file

@ -9,9 +9,10 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"os"
"strings" "strings"
"time" "time"
@ -165,7 +166,7 @@ func newAWSConfig(certPath string) (*awsConfig, error) {
if certPath == "" { if certPath == "" {
certBytes = []byte(awsCertificate) certBytes = []byte(awsCertificate)
} else { } else {
if b, err := ioutil.ReadFile(certPath); err == nil { if b, err := os.ReadFile(certPath); err == nil {
certBytes = b certBytes = b
} else { } else {
return nil, errors.Wrapf(err, "error reading %s", certPath) return nil, errors.Wrapf(err, "error reading %s", certPath)
@ -569,7 +570,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
client := http.Client{} client := http.Client{}
// first get the token // first get the token
req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, nil) req, err := http.NewRequest(http.MethodPut, p.config.tokenURL, http.NoBody)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -582,7 +583,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode) return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
} }
token, err := ioutil.ReadAll(resp.Body) token, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -602,7 +603,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) { func (p *AWS) readResponseBody(resp *http.Response) ([]byte, error) {
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,7 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"io/ioutil" "io"
"net/http" "net/http"
"regexp" "regexp"
"strings" "strings"
@ -173,7 +173,7 @@ func (p *Azure) GetIdentityToken(subject, caURL string) (string, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error reading identity token response") return "", errors.Wrap(err, "error reading identity token response")
} }

View file

@ -107,7 +107,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
w.Write([]byte(t1)) w.Write([]byte(t1))
default: default:
w.Header().Add("Content-Type", "application/json") w.Header().Add("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{"access_token":"%s"}`, t1))) fmt.Fprintf(w, `{"access_token":"%s"}`, t1)
} }
})) }))
defer srv.Close() defer srv.Close()

View file

@ -7,7 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -183,7 +183,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
return "", errors.Wrap(err, "error doing identity request, are you in a GCP VM?") return "", errors.Wrap(err, "error doing identity request, are you in a GCP VM?")
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", errors.Wrap(err, "error on identity request") return "", errors.Wrap(err, "error on identity request")
} }

View file

@ -8,12 +8,15 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
) )
@ -369,8 +372,19 @@ func newValidityValidator(min, max time.Duration) *validityValidator {
return &validityValidator{min: min, max: max} return &validityValidator{min: min, max: max}
} }
// TODO(mariano): refactor errs package to allow sending real errors to the
// user.
func badRequest(format string, args ...interface{}) error {
msg := fmt.Sprintf(format, args...)
return &errs.Error{
Status: http.StatusBadRequest,
Msg: msg,
Err: errors.New(msg),
}
}
// Valid validates the certificate validity settings (notBefore/notAfter) and // Valid validates the certificate validity settings (notBefore/notAfter) and
// and total duration. // total duration.
func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error { func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
var ( var (
na = cert.NotAfter.Truncate(time.Second) na = cert.NotAfter.Truncate(time.Second)
@ -381,22 +395,20 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
d := na.Sub(nb) d := na.Sub(nb)
if na.Before(now) { if na.Before(now) {
return errors.Errorf("notAfter cannot be in the past; na=%v", na) return badRequest("notAfter cannot be in the past; na=%v", na)
} }
if na.Before(nb) { if na.Before(nb) {
return errors.Errorf("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb) return badRequest("notAfter cannot be before notBefore; na=%v, nb=%v", na, nb)
} }
if d < v.min { if d < v.min {
return errors.Errorf("requested duration of %v is less than the authorized minimum certificate duration of %v", return badRequest("requested duration of %v is less than the authorized minimum certificate duration of %v", d, v.min)
d, v.min)
} }
// NOTE: this check is not "technically correct". We're allowing the max // NOTE: this check is not "technically correct". We're allowing the max
// duration of a cert to be "max + backdate" and not all certificates will // duration of a cert to be "max + backdate" and not all certificates will
// be backdated (e.g. if a user passes the NotBefore value then we do not // be backdated (e.g. if a user passes the NotBefore value then we do not
// apply a backdate). This is good enough. // apply a backdate). This is good enough.
if d > v.max+o.Backdate { if d > v.max+o.Backdate {
return errors.Errorf("requested duration of %v is more than the authorized maximum certificate duration of %v", return badRequest("requested duration of %v is more than the authorized maximum certificate duration of %v", d, v.max+o.Backdate)
d, v.max+o.Backdate)
} }
return nil return nil
} }

View file

@ -335,11 +335,11 @@ type sshCertValidityValidator struct {
func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error { func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOptions) error {
switch { switch {
case cert.ValidAfter == 0: case cert.ValidAfter == 0:
return errors.New("ssh certificate validAfter cannot be 0") return badRequest("ssh certificate validAfter cannot be 0")
case cert.ValidBefore < uint64(now().Unix()): case cert.ValidBefore < uint64(now().Unix()):
return errors.New("ssh certificate validBefore cannot be in the past") return badRequest("ssh certificate validBefore cannot be in the past")
case cert.ValidBefore < cert.ValidAfter: case cert.ValidBefore < cert.ValidAfter:
return errors.New("ssh certificate validBefore cannot be before validAfter") return badRequest("ssh certificate validBefore cannot be before validAfter")
} }
var min, max time.Duration var min, max time.Duration
@ -351,9 +351,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
min = v.MinHostSSHCertDuration() min = v.MinHostSSHCertDuration()
max = v.MaxHostSSHCertDuration() max = v.MaxHostSSHCertDuration()
case 0: case 0:
return errors.New("ssh certificate type has not been set") return badRequest("ssh certificate type has not been set")
default: default:
return errors.Errorf("unknown ssh certificate type %d", cert.CertType) return badRequest("unknown ssh certificate type %d", cert.CertType)
} }
// To not take into account the backdate, time.Now() will be used to // To not take into account the backdate, time.Now() will be used to
@ -362,11 +362,9 @@ func (v *sshCertValidityValidator) Valid(cert *ssh.Certificate, opts SignSSHOpti
switch { switch {
case dur < min: case dur < min:
return errors.Errorf("requested duration of %s is less than minimum "+ return badRequest("requested duration of %s is less than minimum accepted duration for selected provisioner of %s", dur, min)
"accepted duration for selected provisioner of %s", dur, min)
case dur > max+opts.Backdate: case dur > max+opts.Backdate:
return errors.Errorf("requested duration of %s is greater than maximum "+ return badRequest("requested duration of %s is greater than maximum accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
"accepted duration for selected provisioner of %s", dur, max+opts.Backdate)
default: default:
return nil return nil
} }

View file

@ -10,9 +10,9 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"strings" "strings"
"time" "time"
@ -188,7 +188,7 @@ func generateJWK() (*JWK, error) {
} }
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
fooPubB, err := ioutil.ReadFile("./testdata/certs/foo.pub") fooPubB, err := os.ReadFile("./testdata/certs/foo.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -196,7 +196,7 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
barPubB, err := ioutil.ReadFile("./testdata/certs/bar.pub") barPubB, err := os.ReadFile("./testdata/certs/bar.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -234,7 +234,7 @@ func generateSSHPOP() (*SSHPOP, error) {
return nil, err return nil, err
} }
userB, err := ioutil.ReadFile("./testdata/certs/ssh_user_ca_key.pub") userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,7 +242,7 @@ func generateSSHPOP() (*SSHPOP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
hostB, err := ioutil.ReadFile("./testdata/certs/ssh_host_ca_key.pub") hostB, err := os.ReadFile("./testdata/certs/ssh_host_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,14 +6,14 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
step "go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/linkedca" "go.step.sm/linkedca"
@ -238,6 +238,8 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
return nil return nil
} }
// CreateFirstProvisioner creates and stores the first provisioner when using
// admin database provisioner storage.
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) { func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
if password == "" { if password == "" {
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one") pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
@ -287,6 +289,7 @@ func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (
return p, nil return p, nil
} }
// ValidateClaims validates the Claims type.
func ValidateClaims(c *linkedca.Claims) error { func ValidateClaims(c *linkedca.Claims) error {
if c == nil { if c == nil {
return nil return nil
@ -313,6 +316,7 @@ func ValidateClaims(c *linkedca.Claims) error {
return nil return nil
} }
// ValidateDurations validates the Durations type.
func ValidateDurations(d *linkedca.Durations) error { func ValidateDurations(d *linkedca.Durations) error {
var ( var (
err error err error
@ -523,8 +527,8 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.X509.Template != "" { if p.X509.Template != "" {
x509Template.Template = []byte(p.SSH.Template) x509Template.Template = []byte(p.SSH.Template)
} else if p.X509.TemplateFile != "" { } else if p.X509.TemplateFile != "" {
filename := step.StepAbs(p.X509.TemplateFile) filename := step.Abs(p.X509.TemplateFile)
if x509Template.Template, err = ioutil.ReadFile(filename); err != nil { if x509Template.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading x509 template") return nil, nil, errors.Wrap(err, "error reading x509 template")
} }
} }
@ -539,8 +543,8 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.SSH.Template != "" { if p.SSH.Template != "" {
sshTemplate.Template = []byte(p.SSH.Template) sshTemplate.Template = []byte(p.SSH.Template)
} else if p.SSH.TemplateFile != "" { } else if p.SSH.TemplateFile != "" {
filename := step.StepAbs(p.SSH.TemplateFile) filename := step.Abs(p.SSH.TemplateFile)
if sshTemplate.Template, err = ioutil.ReadFile(filename); err != nil { if sshTemplate.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading ssh template") return nil, nil, errors.Wrap(err, "error reading ssh template")
} }
} }

View file

@ -101,6 +101,15 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Backwards compatibility for version of the cli older than v0.18.0.
// Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
// from the cli.
if o.Name == "step_includes.tpl" && data[templates.SSHTemplateVersionKey] == "" {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
}
output = append(output, o) output = append(output, o)
} }
return output, nil return output, nil

View file

@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")}, {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
} }
tmplConfigUserIncludes := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}
userOutputEmptyData := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
}
userOutputWithoutTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
userOutputWithTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
tmplConfigErr := &templates.Templates{ tmplConfigErr := &templates.Templates{
SSH: &templates.SSHTemplates{ SSH: &templates.SSHTemplates{
User: []templates.Template{ User: []templates.Template{
@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},

View file

@ -0,0 +1 @@
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}

View file

@ -310,7 +310,7 @@ func TestAuthority_Sign(t *testing.T) {
extraOpts: extraOpts, extraOpts: extraOpts,
signOpts: _signOpts, signOpts: _signOpts,
err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"), err: errors.New("authority.Sign: requested duration of 25h0m0s is more than the authorized maximum certificate duration of 24h1m0s"),
code: http.StatusUnauthorized, code: http.StatusBadRequest,
} }
}, },
"fail validate sans when adding common name not in claims": func(t *testing.T) *signTest { "fail validate sans when adding common name not in claims": func(t *testing.T) *signTest {
@ -538,15 +538,15 @@ ZYtQ9Ot36qc=
if tc.csr.Subject.CommonName == "" { if tc.csr.Subject.CommonName == "" {
assert.Equals(t, leaf.Subject, pkix.Name{}) assert.Equals(t, leaf.Subject, pkix.Name{})
} else { } else {
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: "smallstep test", CommonName: "smallstep test",
})) }.String())
assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"}) assert.Equals(t, leaf.DNSNames, []string{"test.smallstep.com"})
} }
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
@ -718,15 +718,15 @@ func TestAuthority_Renew(t *testing.T) {
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: tmplt.CommonName, CommonName: tmplt.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
@ -925,15 +925,15 @@ func TestAuthority_Rekey(t *testing.T) {
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{tmplt.Country}, Country: []string{tmplt.Country},
Organization: []string{tmplt.Organization}, Organization: []string{tmplt.Organization},
Locality: []string{tmplt.Locality}, Locality: []string{tmplt.Locality},
StreetAddress: []string{tmplt.StreetAddress}, StreetAddress: []string{tmplt.StreetAddress},
Province: []string{tmplt.Province}, Province: []string{tmplt.Province},
CommonName: tmplt.CommonName, CommonName: tmplt.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)

View file

@ -7,7 +7,6 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"strings" "strings"
@ -292,7 +291,7 @@ func (c *ACMEClient) GetCertificate(url string) (*x509.Certificate, []*x509.Cert
return nil, nil, readACMEError(resp.Body) return nil, nil, readACMEError(resp.Body)
} }
defer resp.Body.Close() defer resp.Body.Close()
bodyBytes, err := ioutil.ReadAll(resp.Body) bodyBytes, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, nil, errors.Wrap(err, "error reading GET certificate response") return nil, nil, errors.Wrap(err, "error reading GET certificate response")
} }
@ -338,7 +337,7 @@ func (c *ACMEClient) GetAccountOrders() ([]string, error) {
func readACMEError(r io.ReadCloser) error { func readACMEError(r io.ReadCloser) error {
defer r.Close() defer r.Close()
b, err := ioutil.ReadAll(r) b, err := io.ReadAll(r)
if err != nil { if err != nil {
return errors.Wrap(err, "error reading from body") return errors.Wrap(err, "error reading from body")
} }

View file

@ -5,7 +5,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io/ioutil" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -317,7 +317,7 @@ func TestACMEClient_post(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -455,7 +455,7 @@ func TestACMEClient_NewOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -575,7 +575,7 @@ func TestACMEClient_GetOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -695,7 +695,7 @@ func TestACMEClient_GetAuthz(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -815,7 +815,7 @@ func TestACMEClient_GetChallenge(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -936,7 +936,7 @@ func TestACMEClient_ValidateChallenge(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1061,7 +1061,7 @@ func TestACMEClient_FinalizeOrder(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1188,7 +1188,7 @@ func TestACMEClient_GetAccountOrders(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)
@ -1317,7 +1317,7 @@ func TestACMEClient_GetCertificate(t *testing.T) {
} }
// validate jws request protected headers and body // validate jws request protected headers and body
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(body)) jws, err := jose.ParseJWS(string(body))
assert.FatalError(t, err) assert.FatalError(t, err)

View file

@ -197,7 +197,7 @@ func (c *AdminClient) GetAdminsPaginate(opts ...AdminOption) (*adminAPI.GetAdmin
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create GET %s request failed", u) return nil, errors.Wrapf(err, "create GET %s request failed", u)
} }
@ -284,7 +284,7 @@ func (c *AdminClient) RemoveAdmin(id string) error {
if err != nil { if err != nil {
return errors.Wrapf(err, "error generating admin token") return errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("DELETE", u.String(), nil) req, err := http.NewRequest("DELETE", u.String(), http.NoBody)
if err != nil { if err != nil {
return errors.Wrapf(err, "create DELETE %s request failed", u) return errors.Wrapf(err, "create DELETE %s request failed", u)
} }
@ -363,7 +363,7 @@ func (c *AdminClient) GetProvisioner(opts ...ProvisionerOption) (*linkedca.Provi
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create PUT %s request failed", u) return nil, errors.Wrapf(err, "create PUT %s request failed", u)
} }
@ -402,7 +402,7 @@ func (c *AdminClient) GetProvisionersPaginate(opts ...ProvisionerOption) (*admin
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error generating admin token") return nil, errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("GET", u.String(), nil) req, err := http.NewRequest("GET", u.String(), http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "create PUT %s request failed", u) return nil, errors.Wrapf(err, "create PUT %s request failed", u)
} }
@ -472,7 +472,7 @@ func (c *AdminClient) RemoveProvisioner(opts ...ProvisionerOption) error {
if err != nil { if err != nil {
return errors.Wrapf(err, "error generating admin token") return errors.Wrapf(err, "error generating admin token")
} }
req, err := http.NewRequest("DELETE", u.String(), nil) req, err := http.NewRequest("DELETE", u.String(), http.NoBody)
if err != nil { if err != nil {
return errors.Wrapf(err, "create DELETE %s request failed", u) return errors.Wrapf(err, "create DELETE %s request failed", u)
} }

View file

@ -3,7 +3,7 @@ package ca
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -382,7 +382,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
return errors.Wrapf(err, "client.Get(%s) failed", srvURL) return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errors.Wrap(err, "client.Get() error reading response") return errors.Wrap(err, "client.Get() error reading response")
} }
@ -499,7 +499,7 @@ func TestBootstrapClientServerFederation(t *testing.T) {
return errors.Wrapf(err, "client.Get(%s) failed", srvURL) return errors.Wrapf(err, "client.Get(%s) failed", srvURL)
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return errors.Wrap(err, "client.Get() error reading response") return errors.Wrap(err, "client.Get() error reading response")
} }
@ -589,9 +589,9 @@ func TestBootstrapListener(t *testing.T) {
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.ReadAll() error = %v", err) t.Errorf("io.ReadAll() error = %v", err)
return return
} }
if string(b) != "ok" { if string(b) != "ok" {

View file

@ -294,15 +294,15 @@ ZEp7knvU2psWRw==
assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second)) assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second)) assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
Country: []string{asn1dn.Country}, Country: []string{asn1dn.Country},
Organization: []string{asn1dn.Organization}, Organization: []string{asn1dn.Organization},
Locality: []string{asn1dn.Locality}, Locality: []string{asn1dn.Locality},
StreetAddress: []string{asn1dn.StreetAddress}, StreetAddress: []string{asn1dn.StreetAddress},
Province: []string{asn1dn.Province}, Province: []string{asn1dn.Province},
CommonName: asn1dn.CommonName, CommonName: asn1dn.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)
@ -641,10 +641,10 @@ func TestCARenew(t *testing.T) {
assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second)) assert.Equals(t, leaf.NotBefore, now.Truncate(time.Second))
assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second)) assert.Equals(t, leaf.NotAfter, leafExpiry.Truncate(time.Second))
assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), assert.Equals(t, leaf.Subject.String(),
fmt.Sprintf("%v", &pkix.Name{ pkix.Name{
CommonName: asn1dn.CommonName, CommonName: asn1dn.CommonName,
})) }.String())
assert.Equals(t, leaf.Issuer, intermediate.Subject) assert.Equals(t, leaf.Issuer, intermediate.Subject)
assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256) assert.Equals(t, leaf.SignatureAlgorithm, x509.ECDSAWithSHA256)

View file

@ -15,7 +15,6 @@ import (
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -29,7 +28,7 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -75,7 +74,7 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) {
} }
func (c *uaClient) Get(u string) (*http.Response, error) { func (c *uaClient) Get(u string) (*http.Response, error) {
req, err := http.NewRequest("GET", u, nil) req, err := http.NewRequest("GET", u, http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "new request GET %s failed", u) return nil, errors.Wrapf(err, "new request GET %s failed", u)
} }
@ -226,7 +225,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
return tr, nil return tr, nil
} }
// WithTransport adds a custom transport to the Client. It will fail if a // WithTransport adds a custom transport to the Client. It will fail if a
// previous option to create the transport has been configured. // previous option to create the transport has been configured.
func WithTransport(tr http.RoundTripper) ClientOption { func WithTransport(tr http.RoundTripper) ClientOption {
return func(o *clientOptions) error { return func(o *clientOptions) error {
@ -238,6 +237,17 @@ func WithTransport(tr http.RoundTripper) ClientOption {
} }
} }
// WithInsecure adds a insecure transport that bypasses TLS verification.
func WithInsecure() ClientOption {
return func(o *clientOptions) error {
o.transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return nil
}
}
// WithRootFile will create the transport using the given root certificate. It // WithRootFile will create the transport using the given root certificate. It
// will fail if a previous option to create the transport has been configured. // will fail if a previous option to create the transport has been configured.
func WithRootFile(filename string) ClientOption { func WithRootFile(filename string) ClientOption {
@ -350,7 +360,7 @@ func WithRetryFunc(fn RetryFunc) ClientOption {
} }
func getTransportFromFile(filename string) (http.RoundTripper, error) { func getTransportFromFile(filename string) (http.RoundTripper, error) {
data, err := ioutil.ReadFile(filename) data, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename) return nil, errors.Wrapf(err, "error reading %s", filename)
} }
@ -1295,7 +1305,7 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
// getRootCAPath returns the path where the root CA is stored based on the // getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // STEPPATH environment variable.
func getRootCAPath() string { func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt") return filepath.Join(step.Path(), "certs", "root_ca.crt")
} }
func readJSON(r io.ReadCloser, v interface{}) error { func readJSON(r io.ReadCloser, v interface{}) error {
@ -1305,7 +1315,7 @@ func readJSON(r io.ReadCloser, v interface{}) error {
func readProtoJSON(r io.ReadCloser, m proto.Message) error { func readProtoJSON(r io.ReadCloser, m proto.Message) error {
defer r.Close() defer r.Close()
data, err := ioutil.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return err return err
} }

View file

@ -5,9 +5,9 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@ -27,21 +27,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json // $STEPPATH/config/identity.json
func LoadClient() (*Client, error) { func LoadClient() (*Client, error) {
b, err := ioutil.ReadFile(DefaultsFile) defaultsFile := DefaultsFile()
b, err := os.ReadFile(defaultsFile)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", DefaultsFile) return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
} }
var defaults defaultsConfig var defaults defaultsConfig
if err := json.Unmarshal(b, &defaults); err != nil { if err := json.Unmarshal(b, &defaults); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile) return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile)
} }
if err := defaults.Validate(); err != nil { if err := defaults.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
caURL, err := url.Parse(defaults.CaURL) caURL, err := url.Parse(defaults.CaURL)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
if caURL.Scheme == "" { if caURL.Scheme == "" {
caURL.Scheme = "https" caURL.Scheme = "https"
@ -52,7 +53,7 @@ func LoadClient() (*Client, error) {
return nil, err return nil, err
} }
if err := identity.Validate(); err != nil { if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", IdentityFile) return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
} }
if kind := identity.Kind(); kind != MutualTLS { if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)
@ -65,7 +66,7 @@ func LoadClient() (*Client, error) {
} }
// RootCAs // RootCAs
b, err = ioutil.ReadFile(defaults.Root) b, err = os.ReadFile(defaults.Root)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error loading %s", defaults.Root) return nil, errors.Wrapf(err, "error loading %s", defaults.Root)
} }

View file

@ -3,14 +3,20 @@ package identity
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"os"
"reflect" "reflect"
"testing" "testing"
) )
func returnInput(val string) func() string {
return func() string {
return val
}
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
oldIdentityFile := IdentityFile oldIdentityFile := IdentityFile
oldDefaultsFile := DefaultsFile oldDefaultsFile := DefaultsFile
@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
DefaultsFile = oldDefaultsFile DefaultsFile = oldDefaultsFile
}() }()
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
client, err := LoadClient() client, err := LoadClient()
if err != nil { if err != nil {
@ -40,7 +46,7 @@ func TestClient(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/certs/root_ca.crt") b, err := os.ReadFile("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -114,7 +120,7 @@ func TestLoadClient(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/certs/root_ca.crt") b, err := os.ReadFile("testdata/certs/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", func() { {"ok", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, expected, false}, }, expected, false},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/missing.json" IdentityFile = returnInput("testdata/config/missing.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/fail.json" IdentityFile = returnInput("testdata/config/fail.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/missing.json" DefaultsFile = returnInput("testdata/config/missing.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/fail.json" DefaultsFile = returnInput("testdata/config/fail.json")
}, nil, true}, }, nil, true},
{"fail ca", func() { {"fail ca", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badca.json" DefaultsFile = returnInput("testdata/config/badca.json")
}, nil, true}, }, nil, true},
{"fail root", func() { {"fail root", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badroot.json" DefaultsFile = returnInput("testdata/config/badroot.json")
}, nil, true}, }, nil, true},
{"fail type", func() { {"fail type", func() {
IdentityFile = "testdata/config/badIdentity.json" IdentityFile = returnInput("testdata/config/badIdentity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -7,7 +7,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"io/ioutil"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -16,7 +15,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -39,11 +38,18 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims. // DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute const DefaultLeeway = 1 * time.Minute
// IdentityFile contains the location of the identity file. var (
var IdentityFile = filepath.Join(config.StepPath(), "config", "identity.json") identityDir = step.IdentityPath
configDir = step.ConfigPath
// DefaultsFile contains the location of the defaults file. // IdentityFile contains a pointer to a function that outputs the location of
var DefaultsFile = filepath.Join(config.StepPath(), "config", "defaults.json") // the identity file.
IdentityFile = step.IdentityFile
// DefaultsFile contains a prointer a function that outputs the location of the
// defaults configuration file.
DefaultsFile = step.DefaultsFile
)
// Identity represents the identity file that can be used to authenticate with // Identity represents the identity file that can be used to authenticate with
// the CA. // the CA.
@ -61,7 +67,7 @@ type Identity struct {
// LoadIdentity loads an identity present in the given filename. // LoadIdentity loads an identity present in the given filename.
func LoadIdentity(filename string) (*Identity, error) { func LoadIdentity(filename string) (*Identity, error) {
b, err := ioutil.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", filename) return nil, errors.Wrapf(err, "error reading %s", filename)
} }
@ -74,23 +80,17 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity. // LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) { func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(IdentityFile) return LoadIdentity(IdentityFile())
} }
// configDir and identityDir are used in WriteDefaultIdentity for testing
// purposes.
var (
configDir = filepath.Join(config.StepPath(), "config")
identityDir = filepath.Join(config.StepPath(), "identity")
)
// WriteDefaultIdentity writes the given certificates and key and the // WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files. // identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(configDir, 0700); err != nil { if err := os.MkdirAll(configDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory") return errors.Wrap(err, "error creating config directory")
} }
identityDir := identityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil { if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory") return errors.Wrap(err, "error creating identity directory")
} }
@ -112,7 +112,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
if err := pem.Encode(buf, block); err != nil { if err := pem.Encode(buf, block); err != nil {
return errors.Wrap(err, "error encoding identity key") return errors.Wrap(err, "error encoding identity key")
} }
if err := ioutil.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(keyFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -127,7 +127,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil { }); err != nil {
return errors.Wrap(err, "error writing identity json") return errors.Wrap(err, "error writing identity json")
} }
if err := ioutil.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -136,7 +136,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk. // WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error { func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(identityDir, "identity.crt") filename := filepath.Join(identityDir(), "identity.crt")
return writeCertificate(filename, certChain) return writeCertificate(filename, certChain)
} }
@ -153,7 +153,7 @@ func writeCertificate(filename string, certChain []api.Certificate) error {
} }
} }
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing certificate") return errors.Wrap(err, "error writing certificate")
} }
@ -263,7 +263,7 @@ func (i *Identity) GetCertPool() (*x509.CertPool, error) {
if i.Root == "" { if i.Root == "" {
return nil, nil return nil, nil
} }
b, err := ioutil.ReadFile(i.Root) b, err := os.ReadFile(i.Root)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error reading identity root") return nil, errors.Wrap(err, "error reading identity root")
} }
@ -319,8 +319,8 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate") return errors.Wrap(err, "error encoding identity certificate")
} }
} }
certFilename := filepath.Join(identityDir, "identity.crt") certFilename := filepath.Join(identityDir(), "identity.crt")
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }

View file

@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) {
want *Identity want *Identity
wantErr bool wantErr bool
}{ }{
{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false}, {"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false},
{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true}, {"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true},
{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true}, {"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) {
certChain = append(certChain, api.Certificate{Certificate: c}) certChain = append(certChain, api.Certificate{Certificate: c})
} }
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(tmpDir, "config", "identity.json") IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json"))
type args struct { type args struct {
certChain []api.Certificate certChain []api.Certificate
@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) {
}{ }{
{"ok", func() {}, args{certChain, key}, false}, {"ok", func() {}, args{certChain, key}, false},
{"fail mkdir config", func() { {"fail mkdir config", func() {
configDir = filepath.Join(tmpDir, "identity", "identity.crt") configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail mkdir identity", func() { {"fail mkdir identity", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity", "identity.crt") identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail certificate", func() { {"fail certificate", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail key", func() { {"fail key", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, "badKey"}, true}, }, args{certChain, "badKey"}, true},
{"fail write identity", func() { {"fail write identity", func() {
configDir = filepath.Join(tmpDir, "bad-dir") configDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(configDir, "identity.json") IdentityFile = returnInput(filepath.Join(configDir(), "identity.json"))
os.MkdirAll(configDir, 0600) os.MkdirAll(configDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
} }
@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) {
} }
oldIdentityDir := identityDir oldIdentityDir := identityDir
identityDir = "testdata/identity" identityDir = returnInput("testdata/identity")
defer func() { defer func() {
identityDir = oldIdentityDir identityDir = oldIdentityDir
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) {
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail write identity", func() { {"fail write identity", func() {
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -1,8 +1,8 @@
package ca package ca
import ( import (
"io/ioutil"
"net/url" "net/url"
"os"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -45,7 +45,7 @@ func TestNewProvisioner(t *testing.T) {
defer ca.Close() defer ca.Close()
want := getTestProvisioner(t, ca.URL) want := getTestProvisioner(t, ca.URL)
caBundle, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") caBundle, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -18,7 +18,7 @@ var minCertDuration = time.Minute
// TLSRenewer automatically renews a tls certificate using a RenewFunc. // TLSRenewer automatically renews a tls certificate using a RenewFunc.
type TLSRenewer struct { type TLSRenewer struct {
sync.RWMutex renewMutex sync.RWMutex
RenewCertificate RenewFunc RenewCertificate RenewFunc
cert *tls.Certificate cert *tls.Certificate
timer *time.Timer timer *time.Timer
@ -81,9 +81,9 @@ func NewTLSRenewer(cert *tls.Certificate, fn RenewFunc, opts ...tlsRenewerOption
func (r *TLSRenewer) Run() { func (r *TLSRenewer) Run() {
cert := r.getCertificate() cert := r.getCertificate()
next := r.nextRenewDuration(cert.Leaf.NotAfter) next := r.nextRenewDuration(cert.Leaf.NotAfter)
r.Lock() r.renewMutex.Lock()
r.timer = time.AfterFunc(next, r.renewCertificate) r.timer = time.AfterFunc(next, r.renewCertificate)
r.Unlock() r.renewMutex.Unlock()
} }
// RunContext starts the certificate renewer for the given certificate. // RunContext starts the certificate renewer for the given certificate.
@ -133,25 +133,25 @@ func (r *TLSRenewer) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Cer
// if the timer does not fire e.g. when the CA is run from a laptop that // if the timer does not fire e.g. when the CA is run from a laptop that
// enters sleep mode. // enters sleep mode.
func (r *TLSRenewer) getCertificate() *tls.Certificate { func (r *TLSRenewer) getCertificate() *tls.Certificate {
r.RLock() r.renewMutex.RLock()
cert := r.cert cert := r.cert
r.RUnlock() r.renewMutex.RUnlock()
return cert return cert
} }
// getCertificateForCA returns the certificate using a read-only lock. It will // getCertificateForCA returns the certificate using a read-only lock. It will
// automatically renew the certificate if it has expired. // automatically renew the certificate if it has expired.
func (r *TLSRenewer) getCertificateForCA() *tls.Certificate { func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
r.RLock() r.renewMutex.RLock()
// Force certificate renewal if the timer didn't run. // Force certificate renewal if the timer didn't run.
// This is an special case that can happen after a computer sleep. // This is an special case that can happen after a computer sleep.
if time.Now().After(r.certNotAfter) { if time.Now().After(r.certNotAfter) {
r.RUnlock() r.renewMutex.RUnlock()
r.renewCertificate() r.renewCertificate()
r.RLock() r.renewMutex.RLock()
} }
cert := r.cert cert := r.cert
r.RUnlock() r.renewMutex.RUnlock()
return cert return cert
} }
@ -159,10 +159,10 @@ func (r *TLSRenewer) getCertificateForCA() *tls.Certificate {
// updates certNotAfter with 1m of delta; this will force the renewal of the // updates certNotAfter with 1m of delta; this will force the renewal of the
// certificate if it is about to expire. // certificate if it is about to expire.
func (r *TLSRenewer) setCertificate(cert *tls.Certificate) { func (r *TLSRenewer) setCertificate(cert *tls.Certificate) {
r.Lock() r.renewMutex.Lock()
r.cert = cert r.cert = cert
r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute) r.certNotAfter = cert.Leaf.NotAfter.Add(-1 * time.Minute)
r.Unlock() r.renewMutex.Unlock()
} }
func (r *TLSRenewer) renewCertificate() { func (r *TLSRenewer) renewCertificate() {
@ -175,9 +175,9 @@ func (r *TLSRenewer) renewCertificate() {
r.setCertificate(cert) r.setCertificate(cert)
next = r.nextRenewDuration(cert.Leaf.NotAfter) next = r.nextRenewDuration(cert.Leaf.NotAfter)
} }
r.Lock() r.renewMutex.Lock()
r.timer.Reset(next) r.timer.Reset(next)
r.Unlock() r.renewMutex.Unlock()
} }
func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration { func (r *TLSRenewer) nextRenewDuration(notAfter time.Time) time.Duration {

View file

@ -4,8 +4,8 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"os"
"reflect" "reflect"
"sort" "sort"
"testing" "testing"
@ -202,7 +202,7 @@ func TestAddRootsToRootCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -256,7 +256,7 @@ func TestAddRootsToClientCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -310,12 +310,12 @@ func TestAddFederationToRootCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -374,12 +374,12 @@ func TestAddFederationToClientCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -438,7 +438,7 @@ func TestAddRootsToCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -492,12 +492,12 @@ func TestAddFederationToCAs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
root, err := ioutil.ReadFile("testdata/secrets/root_ca.crt") root, err := os.ReadFile("testdata/secrets/root_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
federated, err := ioutil.ReadFile("testdata/secrets/federated_ca.crt") federated, err := os.ReadFile("testdata/secrets/federated_ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -8,7 +8,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"io/ioutil" "io"
"log" "log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -221,7 +221,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) {
return return
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Fatalf("ioutil.RealAdd() error = %v", err) t.Fatalf("ioutil.RealAdd() error = %v", err)
} }
@ -335,7 +335,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err) t.Errorf("ioutil.RealAdd() error = %v", err)
return return
@ -374,9 +374,9 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
t.Errorf("ioutil.RealAdd() error = %v", err) t.Errorf("io.ReadAll() error = %v", err)
return return
} }
if !bytes.Equal(b, []byte("ok")) { if !bytes.Equal(b, []byte("ok")) {

View file

@ -91,7 +91,7 @@ func mustSerializeCrt(filename string, certs ...*x509.Certificate) {
panic(err) panic(err)
} }
} }
if err := ioutil.WriteFile(filename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(filename, buf.Bytes(), 0600); err != nil {
panic(err) panic(err)
} }
} }
@ -105,7 +105,7 @@ func mustSerializeKey(filename string, key crypto.Signer) {
Type: "PRIVATE KEY", Type: "PRIVATE KEY",
Bytes: b, Bytes: b,
}) })
if err := ioutil.WriteFile(filename, b, 0600); err != nil { if err := os.WriteFile(filename, b, 0600); err != nil {
panic(err) panic(err)
} }
} }

View file

@ -21,7 +21,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
"go.step.sm/cli-utils/command" "go.step.sm/cli-utils/command"
"go.step.sm/cli-utils/command/version" "go.step.sm/cli-utils/command/version"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/cli-utils/usage" "go.step.sm/cli-utils/usage"
@ -49,7 +49,7 @@ var (
) )
func init() { func init() {
config.Set("Smallstep CA", Version, BuildTime) step.Set("Smallstep CA", Version, BuildTime)
authority.GlobalVersion.Version = Version authority.GlobalVersion.Version = Version
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }
@ -115,7 +115,7 @@ func main() {
app := cli.NewApp() app := cli.NewApp()
app.Name = "step-ca" app.Name = "step-ca"
app.HelpName = "step-ca" app.HelpName = "step-ca"
app.Version = config.Version() app.Version = step.Version()
app.Usage = "an online certificate authority for secure automated certificate management" app.Usage = "an online certificate authority for secure automated certificate management"
app.UsageText = `**step-ca** <config> [**--password-file**=<file>] app.UsageText = `**step-ca** <config> [**--password-file**=<file>]
[**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>] [**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>]

View file

@ -32,7 +32,7 @@ import (
// Config is a mapping of the cli flags. // Config is a mapping of the cli flags.
type Config struct { type Config struct {
KMS string KMS string
RootOnly bool GenerateRoot bool
RootObject string RootObject string
RootKeyObject string RootKeyObject string
RootSubject string RootSubject string
@ -58,22 +58,29 @@ func (c *Config) Validate() error {
switch { switch {
case c.KMS == "": case c.KMS == "":
return errors.New("flag `--kms` is required") return errors.New("flag `--kms` is required")
case c.CrtPath == "":
return errors.New("flag `--crt-cert-path` is required")
case c.RootFile != "" && c.KeyFile == "": case c.RootFile != "" && c.KeyFile == "":
return errors.New("flag `--root` requires flag `--key`") return errors.New("flag `--root-cert-file` requires flag `--root-key-file`")
case c.KeyFile != "" && c.RootFile == "": case c.KeyFile != "" && c.RootFile == "":
return errors.New("flag `--key` requires flag `--root`") return errors.New("flag `--root-key-file` requires flag `--root-cert-file`")
case c.RootOnly && c.RootFile != "":
return errors.New("flag `--root-only` is incompatible with flag `--root`")
case c.RootFile == "" && c.RootObject == "": case c.RootFile == "" && c.RootObject == "":
return errors.New("one of flag `--root` or `--root-cert` is required") return errors.New("one of flag `--root-cert-file` or `--root-cert-obj` is required")
case c.RootFile == "" && c.RootKeyObject == "": case c.KeyFile == "" && c.RootKeyObject == "":
return errors.New("one of flag `--root` or `--root-key` is required") return errors.New("one of flag `--root-key-file` or `--root-key-obj` is required")
case c.CrtKeyPath == "" && c.CrtKeyObject == "":
return errors.New("one of flag `--crt-key-path` or `--crt-key-obj` is required")
case c.RootFile == "" && c.GenerateRoot && c.RootKeyObject == "":
return errors.New("flag `--root-gen` requires flag `--root-key-obj`")
case c.RootFile == "" && c.GenerateRoot && c.RootPath == "":
return errors.New("flag `--root-gen` requires `--root-cert-path`")
default: default:
if c.RootFile != "" { if c.RootFile != "" {
c.GenerateRoot = false
c.RootObject = "" c.RootObject = ""
c.RootKeyObject = "" c.RootKeyObject = ""
} }
if c.RootOnly { if c.CrtKeyPath != "" {
c.CrtObject = "" c.CrtObject = ""
c.CrtKeyObject = "" c.CrtKeyObject = ""
} }
@ -101,21 +108,27 @@ func main() {
var c Config var c Config
flag.StringVar(&c.KMS, "kms", kmsuri, "PKCS #11 URI with the module-path and token to connect to the module.") flag.StringVar(&c.KMS, "kms", kmsuri, "PKCS #11 URI with the module-path and token to connect to the module.")
flag.StringVar(&c.Pin, "pin", "", "PKCS #11 PIN") flag.StringVar(&c.Pin, "pin", "", "PKCS #11 PIN")
flag.StringVar(&c.RootObject, "root-cert", "pkcs11:id=7330;object=root-cert", "PKCS #11 URI with object id and label to store the root certificate.") // Option 1: Generate new root
flag.StringVar(&c.RootPath, "root-cert-path", "root_ca.crt", "Location to write the root certificate.") flag.BoolVar(&c.GenerateRoot, "root-gen", true, "Enable the generation of a root key.")
flag.StringVar(&c.RootKeyObject, "root-key", "pkcs11:id=7330;object=root-key", "PKCS #11 URI with object id and label to store the root key.")
flag.StringVar(&c.RootSubject, "root-name", "PKCS #11 Smallstep Root", "Subject and Issuer of the root certificate.") flag.StringVar(&c.RootSubject, "root-name", "PKCS #11 Smallstep Root", "Subject and Issuer of the root certificate.")
flag.StringVar(&c.CrtObject, "crt-cert", "pkcs11:id=7331;object=intermediate-cert", "PKCS #11 URI with object id and label to store the intermediate certificate.") flag.StringVar(&c.RootObject, "root-cert-obj", "pkcs11:id=7330;object=root-cert", "PKCS #11 URI with object id and label to store the root certificate.")
flag.StringVar(&c.CrtPath, "crt-cert-path", "intermediate_ca.crt", "Location to write the intermediate certificate.") flag.StringVar(&c.RootKeyObject, "root-key-obj", "pkcs11:id=7330;object=root-key", "PKCS #11 URI with object id and label to store the root key.")
flag.StringVar(&c.CrtKeyObject, "crt-key", "pkcs11:id=7331;object=intermediate-key", "PKCS #11 URI with object id and label to store the intermediate certificate.") // Option 2: Read root from disk and sign intermediate
flag.StringVar(&c.RootFile, "root-cert-file", "", "Path to the root certificate to use.")
flag.StringVar(&c.KeyFile, "root-key-file", "", "Path to the root key to use.")
// Option 3: Generate certificate signing request
flag.StringVar(&c.CrtSubject, "crt-name", "PKCS #11 Smallstep Intermediate", "Subject of the intermediate certificate.") flag.StringVar(&c.CrtSubject, "crt-name", "PKCS #11 Smallstep Intermediate", "Subject of the intermediate certificate.")
flag.StringVar(&c.CrtKeyPath, "crt-key-path", "intermediate_ca_key", "Location to write the intermediate private key.") flag.StringVar(&c.CrtObject, "crt-cert-obj", "pkcs11:id=7331;object=intermediate-cert", "PKCS #11 URI with object id and label to store the intermediate certificate.")
flag.StringVar(&c.CrtKeyObject, "crt-key-obj", "pkcs11:id=7331;object=intermediate-key", "PKCS #11 URI with object id and label to store the intermediate certificate.")
// SSH certificates
flag.BoolVar(&c.EnableSSH, "ssh", false, "Enable the creation of ssh keys.")
flag.StringVar(&c.SSHHostKeyObject, "ssh-host-key", "pkcs11:id=7332;object=ssh-host-key", "PKCS #11 URI with object id and label to store the key used to sign SSH host certificates.") flag.StringVar(&c.SSHHostKeyObject, "ssh-host-key", "pkcs11:id=7332;object=ssh-host-key", "PKCS #11 URI with object id and label to store the key used to sign SSH host certificates.")
flag.StringVar(&c.SSHUserKeyObject, "ssh-user-key", "pkcs11:id=7333;object=ssh-user-key", "PKCS #11 URI with object id and label to store the key used to sign SSH user certificates.") flag.StringVar(&c.SSHUserKeyObject, "ssh-user-key", "pkcs11:id=7333;object=ssh-user-key", "PKCS #11 URI with object id and label to store the key used to sign SSH user certificates.")
flag.BoolVar(&c.RootOnly, "root-only", false, "Store only only the root certificate and sign and intermediate.") // Output files
flag.StringVar(&c.RootFile, "root", "", "Path to the root certificate to use.") flag.StringVar(&c.RootPath, "root-cert-path", "root_ca.crt", "Location to write the root certificate.")
flag.StringVar(&c.KeyFile, "key", "", "Path to the root key to use.") flag.StringVar(&c.CrtPath, "crt-cert-path", "intermediate_ca.crt", "Location to write the intermediate certificate.")
flag.BoolVar(&c.EnableSSH, "ssh", false, "Enable the creation of ssh keys.") flag.StringVar(&c.CrtKeyPath, "crt-key-path", "", "Location to write the intermediate private key.")
// Others
flag.BoolVar(&c.NoCerts, "no-certs", false, "Do not store certificates in the module.") flag.BoolVar(&c.NoCerts, "no-certs", false, "Do not store certificates in the module.")
flag.BoolVar(&c.Force, "force", false, "Force the delete of previous keys.") flag.BoolVar(&c.Force, "force", false, "Force the delete of previous keys.")
flag.BoolVar(&c.Extractable, "extractable", false, "Allow export of private keys under wrap.") flag.BoolVar(&c.Extractable, "extractable", false, "Allow export of private keys under wrap.")
@ -276,22 +289,8 @@ func createPKI(k kms.KeyManager, c Config) error {
// Root Certificate // Root Certificate
var signer crypto.Signer var signer crypto.Signer
var root *x509.Certificate var root *x509.Certificate
if c.RootFile != "" && c.KeyFile != "" { switch {
root, err = pemutil.ReadCertificate(c.RootFile) case c.GenerateRoot:
if err != nil {
return err
}
key, err := pemutil.Read(c.KeyFile)
if err != nil {
return err
}
var ok bool
if signer, ok = key.(crypto.Signer); !ok {
return errors.Errorf("key type '%T' does not implement a signer", key)
}
} else {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: c.RootKeyObject, Name: c.RootKeyObject,
SignatureAlgorithm: apiv1.ECDSAWithSHA256, SignatureAlgorithm: apiv1.ECDSAWithSHA256,
@ -331,7 +330,7 @@ func createPKI(k kms.KeyManager, c Config) error {
return errors.Wrap(err, "error parsing root certificate") return errors.Wrap(err, "error parsing root certificate")
} }
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts { if cm, ok := k.(kms.CertificateManager); ok && c.RootObject != "" && !c.NoCerts {
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{ if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.RootObject, Name: c.RootObject,
Certificate: root, Certificate: root,
@ -339,6 +338,8 @@ func createPKI(k kms.KeyManager, c Config) error {
}); err != nil { }); err != nil {
return err return err
} }
} else {
c.RootObject = ""
} }
if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{ if err := fileutil.WriteFile(c.RootPath, pem.EncodeToMemory(&pem.Block{
@ -350,12 +351,31 @@ func createPKI(k kms.KeyManager, c Config) error {
ui.PrintSelected("Root Key", resp.Name) ui.PrintSelected("Root Key", resp.Name)
ui.PrintSelected("Root Certificate", c.RootPath) ui.PrintSelected("Root Certificate", c.RootPath)
if c.RootObject != "" {
ui.PrintSelected("Root Certificate Object", c.RootObject)
}
case c.RootFile != "" && c.KeyFile != "": // Read Root From File
root, err = pemutil.ReadCertificate(c.RootFile)
if err != nil {
return err
}
key, err := pemutil.Read(c.KeyFile)
if err != nil {
return err
}
var ok bool
if signer, ok = key.(crypto.Signer); !ok {
return errors.Errorf("key type '%T' does not implement a signer", key)
}
} }
// Intermediate Certificate // Intermediate Certificate
var keyName string var keyName string
var publicKey crypto.PublicKey var publicKey crypto.PublicKey
if c.RootOnly { var intSigner crypto.Signer
if c.CrtKeyPath != "" {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return errors.Wrap(err, "error creating intermediate key") return errors.Wrap(err, "error creating intermediate key")
@ -373,6 +393,7 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
publicKey = priv.Public() publicKey = priv.Public()
intSigner = priv
} else { } else {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{
Name: c.CrtKeyObject, Name: c.CrtKeyObject,
@ -384,56 +405,89 @@ func createPKI(k kms.KeyManager, c Config) error {
} }
publicKey = resp.PublicKey publicKey = resp.PublicKey
keyName = resp.Name keyName = resp.Name
}
template := &x509.Certificate{ intSigner, err = k.CreateSigner(&resp.CreateSignerRequest)
IsCA: true, if err != nil {
NotBefore: now,
NotAfter: now.Add(time.Hour * 24 * 365 * 10),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
MaxPathLen: 0,
MaxPathLenZero: true,
Issuer: root.Subject,
Subject: pkix.Name{CommonName: c.CrtSubject},
SerialNumber: mustSerialNumber(),
SubjectKeyId: mustSubjectKeyID(publicKey),
}
b, err := x509.CreateCertificate(rand.Reader, template, root, publicKey, signer)
if err != nil {
return err
}
intermediate, err := x509.ParseCertificate(b)
if err != nil {
return errors.Wrap(err, "error parsing intermediate certificate")
}
if cm, ok := k.(kms.CertificateManager); ok && !c.NoCerts {
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.CrtObject,
Certificate: intermediate,
Extractable: c.Extractable,
}); err != nil {
return err return err
} }
} }
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{ if root != nil {
Type: "CERTIFICATE", template := &x509.Certificate{
Bytes: b, IsCA: true,
}), 0600); err != nil { NotBefore: now,
return err NotAfter: now.Add(time.Hour * 24 * 365 * 10),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
MaxPathLen: 0,
MaxPathLenZero: true,
Issuer: root.Subject,
Subject: pkix.Name{CommonName: c.CrtSubject},
SerialNumber: mustSerialNumber(),
SubjectKeyId: mustSubjectKeyID(publicKey),
}
b, err := x509.CreateCertificate(rand.Reader, template, root, publicKey, signer)
if err != nil {
return err
}
intermediate, err := x509.ParseCertificate(b)
if err != nil {
return errors.Wrap(err, "error parsing intermediate certificate")
}
if cm, ok := k.(kms.CertificateManager); ok && c.CrtObject != "" && !c.NoCerts {
if err := cm.StoreCertificate(&apiv1.StoreCertificateRequest{
Name: c.CrtObject,
Certificate: intermediate,
Extractable: c.Extractable,
}); err != nil {
return err
}
} else {
c.CrtObject = ""
}
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: b,
}), 0600); err != nil {
return err
}
} else {
// No root available, generate CSR for external root.
csrTemplate := x509.CertificateRequest{
Subject: pkix.Name{CommonName: c.CrtSubject},
SignatureAlgorithm: x509.ECDSAWithSHA256,
}
// step: generate the csr request
csrCertificate, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, intSigner)
if err != nil {
return err
}
if err := fileutil.WriteFile(c.CrtPath, pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE REQUEST",
Bytes: csrCertificate,
}), 0600); err != nil {
return err
}
} }
if c.RootOnly { if c.CrtKeyPath != "" {
ui.PrintSelected("Intermediate Key", c.CrtKeyPath) ui.PrintSelected("Intermediate Key", c.CrtKeyPath)
} else { } else {
ui.PrintSelected("Intermediate Key", keyName) ui.PrintSelected("Intermediate Key", keyName)
} }
ui.PrintSelected("Intermediate Certificate", c.CrtPath) if root != nil {
ui.PrintSelected("Intermediate Certificate", c.CrtPath)
if c.CrtObject != "" {
ui.PrintSelected("Intermediate Certificate Object", c.CrtObject)
}
} else {
ui.PrintSelected("Intermediate Certificate Request", c.CrtPath)
}
if c.SSHHostKeyObject != "" { if c.SSHHostKeyObject != "" {
resp, err := k.CreateKey(&apiv1.CreateKeyRequest{ resp, err := k.CreateKey(&apiv1.CreateKeyRequest{

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -98,7 +97,7 @@ To get a linked authority token:
var password []byte var password []byte
if passFile != "" { if passFile != "" {
if password, err = ioutil.ReadFile(passFile); err != nil { if password, err = os.ReadFile(passFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", passFile)) fatal(errors.Wrapf(err, "error reading %s", passFile))
} }
password = bytes.TrimRightFunc(password, unicode.IsSpace) password = bytes.TrimRightFunc(password, unicode.IsSpace)
@ -106,7 +105,7 @@ To get a linked authority token:
var sshHostPassword []byte var sshHostPassword []byte
if sshHostPassFile != "" { if sshHostPassFile != "" {
if sshHostPassword, err = ioutil.ReadFile(sshHostPassFile); err != nil { if sshHostPassword, err = os.ReadFile(sshHostPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", sshHostPassFile)) fatal(errors.Wrapf(err, "error reading %s", sshHostPassFile))
} }
sshHostPassword = bytes.TrimRightFunc(sshHostPassword, unicode.IsSpace) sshHostPassword = bytes.TrimRightFunc(sshHostPassword, unicode.IsSpace)
@ -114,7 +113,7 @@ To get a linked authority token:
var sshUserPassword []byte var sshUserPassword []byte
if sshUserPassFile != "" { if sshUserPassFile != "" {
if sshUserPassword, err = ioutil.ReadFile(sshUserPassFile); err != nil { if sshUserPassword, err = os.ReadFile(sshUserPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", sshUserPassFile)) fatal(errors.Wrapf(err, "error reading %s", sshUserPassFile))
} }
sshUserPassword = bytes.TrimRightFunc(sshUserPassword, unicode.IsSpace) sshUserPassword = bytes.TrimRightFunc(sshUserPassword, unicode.IsSpace)
@ -122,7 +121,7 @@ To get a linked authority token:
var issuerPassword []byte var issuerPassword []byte
if issuerPassFile != "" { if issuerPassFile != "" {
if issuerPassword, err = ioutil.ReadFile(issuerPassFile); err != nil { if issuerPassword, err = os.ReadFile(issuerPassFile); err != nil {
fatal(errors.Wrapf(err, "error reading %s", issuerPassFile)) fatal(errors.Wrapf(err, "error reading %s", issuerPassFile))
} }
issuerPassword = bytes.TrimRightFunc(issuerPassword, unicode.IsSpace) issuerPassword = bytes.TrimRightFunc(issuerPassword, unicode.IsSpace)

View file

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "os"
"unicode" "unicode"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -72,14 +72,14 @@ func exportAction(ctx *cli.Context) error {
} }
if passwordFile != "" { if passwordFile != "" {
b, err := ioutil.ReadFile(passwordFile) b, err := os.ReadFile(passwordFile)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", passwordFile) return errors.Wrapf(err, "error reading %s", passwordFile)
} }
cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace)) cfg.Password = string(bytes.TrimRightFunc(b, unicode.IsSpace))
} }
if issuerPasswordFile != "" { if issuerPasswordFile != "" {
b, err := ioutil.ReadFile(issuerPasswordFile) b, err := os.ReadFile(issuerPasswordFile)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", issuerPasswordFile) return errors.Wrapf(err, "error reading %s", issuerPasswordFile)
} }

View file

@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"time" "time"
@ -32,7 +32,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -3,7 +3,7 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"time" "time"
@ -32,7 +32,7 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
b, err := ioutil.ReadAll(resp.Body) b, err := io.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
panic(err) panic(err)

15
go.mod
View file

@ -1,6 +1,6 @@
module github.com/smallstep/certificates module github.com/smallstep/certificates
go 1.15 go 1.16
require ( require (
cloud.google.com/go v0.83.0 cloud.google.com/go v0.83.0
@ -29,10 +29,10 @@ require (
github.com/rs/xid v1.2.1 github.com/rs/xid v1.2.1
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262
github.com/smallstep/nosql v0.3.8 github.com/smallstep/nosql v0.3.9
github.com/urfave/cli v1.22.4 github.com/urfave/cli v1.22.4
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.6.2 go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.13.0 go.step.sm/crypto v0.13.0
go.step.sm/linkedca v0.7.0 go.step.sm/linkedca v0.7.0
golang.org/x/crypto v0.0.0-20210915214749-c084706c2272 golang.org/x/crypto v0.0.0-20210915214749-c084706c2272
@ -44,7 +44,8 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/square/go-jose.v2 v2.6.0
) )
// replace github.com/smallstep/nosql => ../nosql //replace github.com/smallstep/nosql => ../nosql
// replace go.step.sm/crypto => ../crypto
// replace go.step.sm/cli-utils => ../cli-utils //replace go.step.sm/crypto => ../crypto
// replace go.step.sm/linkedca => ../linkedca
//replace go.step.sm/cli-utils => ../cli-utils

11
go.sum
View file

@ -367,10 +367,8 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.0 h1:s5hAObm+yFO5uHYt5dYjxi2rXrsnmRpJx4OYvIWUaQs=
github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
@ -496,8 +494,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd
github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
github.com/smallstep/nosql v0.3.8 h1:1/EWUbbEdz9ai0g9Fd09VekVjtxp+5+gIHpV2PdwW3o= github.com/smallstep/nosql v0.3.9 h1:YPy5PR3PXClqmpFaVv0wfXDXDc7NXGBE1auyU2c87dc=
github.com/smallstep/nosql v0.3.8/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs= github.com/smallstep/nosql v0.3.9/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
@ -561,8 +559,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.step.sm/cli-utils v0.6.2 h1:ofa3G/EqE3dTDXmzoXHDZr18qJZoFsKSzbzuF+mxuZU= go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds=
go.step.sm/cli-utils v0.6.2/go.mod h1:0tZ8F2QwLgD6KbKj4nrQZhMakTasEAnOcW3Ekc5pnrA= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E=
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8= go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8=
go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg= go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg=
@ -956,7 +954,6 @@ google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= gopkg.in/cheggaaa/pb.v1 v1.0.25/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=

View file

@ -4,7 +4,7 @@ import (
"context" "context"
"crypto" "crypto"
"fmt" "fmt"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -165,7 +165,7 @@ func TestCloudKMS_Close(t *testing.T) {
func TestCloudKMS_CreateSigner(t *testing.T) { func TestCloudKMS_CreateSigner(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -223,7 +223,7 @@ func TestCloudKMS_CreateKey(t *testing.T) {
testError := fmt.Errorf("an error") testError := fmt.Errorf("an error")
alreadyExists := status.Error(codes.AlreadyExists, "already exists") alreadyExists := status.Error(codes.AlreadyExists, "already exists")
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -389,7 +389,7 @@ func TestCloudKMS_GetPublicKey(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1" keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
testError := fmt.Errorf("an error") testError := fmt.Errorf("an error")
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -7,7 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io" "io"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -17,7 +17,7 @@ import (
) )
func Test_newSigner(t *testing.T) { func Test_newSigner(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -70,7 +70,7 @@ func Test_newSigner(t *testing.T) {
} }
func Test_signer_Public(t *testing.T) { func Test_signer_Public(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -159,7 +159,7 @@ func Test_signer_Sign(t *testing.T) {
} }
func TestSigner_SignatureAlgorithm(t *testing.T) { func TestSigner_SignatureAlgorithm(t *testing.T) {
pemBytes, err := ioutil.ReadFile("testdata/pub.pem") pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -11,7 +11,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io/ioutil" "os"
"reflect" "reflect"
"testing" "testing"
@ -78,7 +78,7 @@ func TestSoftKMS_CreateSigner(t *testing.T) {
} }
// Read and decode file using standard packages // Read and decode file using standard packages
b, err := ioutil.ReadFile("testdata/priv.pem") b, err := os.ReadFile("testdata/priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -234,7 +234,7 @@ func TestSoftKMS_CreateKey(t *testing.T) {
} }
func TestSoftKMS_GetPublicKey(t *testing.T) { func TestSoftKMS_GetPublicKey(t *testing.T) {
b, err := ioutil.ReadFile("testdata/pub.pem") b, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -332,7 +332,7 @@ func TestSoftKMS_CreateDecrypter(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err := ioutil.ReadFile("testdata/rsa.priv.pem") b, err := os.ReadFile("testdata/rsa.priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -9,7 +9,6 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"io/ioutil"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@ -202,7 +201,7 @@ func TestNew(t *testing.T) {
}) })
// Load ssh test fixtures // Load ssh test fixtures
b, err := ioutil.ReadFile("testdata/ssh") b, err := os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -290,7 +289,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
} }
// Read and decode file using standard packages // Read and decode file using standard packages
b, err := ioutil.ReadFile("testdata/priv.pem") b, err := os.ReadFile("testdata/priv.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -315,7 +314,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
}) })
// Load ssh test fixtures // Load ssh test fixtures
sshPubKeyStr, err := ioutil.ReadFile("testdata/ssh.pub") sshPubKeyStr, err := os.ReadFile("testdata/ssh.pub")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -323,7 +322,7 @@ func TestSSHAgentKMS_CreateSigner(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err = ioutil.ReadFile("testdata/ssh") b, err = os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -499,7 +498,7 @@ func TestSSHAgentKMS_CreateKey(t *testing.T) {
*/ */
func TestSSHAgentKMS_GetPublicKey(t *testing.T) { func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
b, err := ioutil.ReadFile("testdata/pub.pem") b, err := os.ReadFile("testdata/pub.pem")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -510,7 +509,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
} }
// Load ssh test fixtures // Load ssh test fixtures
b, err = ioutil.ReadFile("testdata/ssh.pub") b, err = os.ReadFile("testdata/ssh.pub")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -518,7 +517,7 @@ func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
b, err = ioutil.ReadFile("testdata/ssh") b, err = os.ReadFile("testdata/ssh")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -3,8 +3,8 @@ package uri
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"io/ioutil"
"net/url" "net/url"
"os"
"strings" "strings"
"unicode" "unicode"
@ -140,7 +140,7 @@ func readFile(path string) ([]byte, error) {
if err == nil && (u.Scheme == "" || u.Scheme == "file") && u.Path != "" { if err == nil && (u.Scheme == "" || u.Scheme == "file") && u.Path != "" {
path = u.Path path = u.Path
} }
b, err := ioutil.ReadFile(path) b, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", path) return nil, errors.Wrapf(err, "error reading %s", path)
} }

View file

@ -29,9 +29,9 @@ import (
"github.com/smallstep/certificates/kms" "github.com/smallstep/certificates/kms"
kmsapi "github.com/smallstep/certificates/kms/apiv1" kmsapi "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -87,44 +87,50 @@ const (
) )
// GetDBPath returns the path where the file-system persistence is stored // GetDBPath returns the path where the file-system persistence is stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetDBPath() string { func GetDBPath() string {
return filepath.Join(config.StepPath(), dbPath) return filepath.Join(step.Path(), dbPath)
} }
// GetConfigPath returns the directory where the configuration files are stored // GetConfigPath returns the directory where the configuration files are stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetConfigPath() string { func GetConfigPath() string {
return filepath.Join(config.StepPath(), configPath) return filepath.Join(step.Path(), configPath)
}
// GetProfileConfigPath returns the directory where the profile configuration
// files are stored based on the $(step path).
func GetProfileConfigPath() string {
return filepath.Join(step.ProfilePath(), configPath)
} }
// GetPublicPath returns the directory where the public keys are stored based on // GetPublicPath returns the directory where the public keys are stored based on
// the STEPPATH environment variable. // the $(step path).
func GetPublicPath() string { func GetPublicPath() string {
return filepath.Join(config.StepPath(), publicPath) return filepath.Join(step.Path(), publicPath)
} }
// GetSecretsPath returns the directory where the private keys are stored based // GetSecretsPath returns the directory where the private keys are stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetSecretsPath() string { func GetSecretsPath() string {
return filepath.Join(config.StepPath(), privatePath) return filepath.Join(step.Path(), privatePath)
} }
// GetRootCAPath returns the path where the root CA is stored based on the // GetRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // $(step path).
func GetRootCAPath() string { func GetRootCAPath() string {
return filepath.Join(config.StepPath(), publicPath, "root_ca.crt") return filepath.Join(step.Path(), publicPath, "root_ca.crt")
} }
// GetOTTKeyPath returns the path where the one-time token key is stored based // GetOTTKeyPath returns the path where the one-time token key is stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetOTTKeyPath() string { func GetOTTKeyPath() string {
return filepath.Join(config.StepPath(), privatePath, "ott_key") return filepath.Join(step.Path(), privatePath, "ott_key")
} }
// GetTemplatesPath returns the path where the templates are stored. // GetTemplatesPath returns the path where the templates are stored.
func GetTemplatesPath() string { func GetTemplatesPath() string {
return filepath.Join(config.StepPath(), templatesPath) return filepath.Join(step.Path(), templatesPath)
} }
// GetProvisioners returns the map of provisioners on the given CA. // GetProvisioners returns the map of provisioners on the given CA.
@ -286,20 +292,22 @@ func WithKeyURIs(rootKey, intermediateKey, hostKey, userKey string) Option {
// PKI represents the Public Key Infrastructure used by a certificate authority. // PKI represents the Public Key Infrastructure used by a certificate authority.
type PKI struct { type PKI struct {
linkedca.Configuration linkedca.Configuration
Defaults linkedca.Defaults Defaults linkedca.Defaults
casOptions apiv1.Options casOptions apiv1.Options
caService apiv1.CertificateAuthorityService caService apiv1.CertificateAuthorityService
caCreator apiv1.CertificateAuthorityCreator caCreator apiv1.CertificateAuthorityCreator
keyManager kmsapi.KeyManager keyManager kmsapi.KeyManager
config string config string
defaults string defaults string
ottPublicKey *jose.JSONWebKey profileDefaults string
ottPrivateKey *jose.JSONWebEncryption ottPublicKey *jose.JSONWebKey
options *options ottPrivateKey *jose.JSONWebEncryption
options *options
} }
// New creates a new PKI configuration. // New creates a new PKI configuration.
func New(o apiv1.Options, opts ...Option) (*PKI, error) { func New(o apiv1.Options, opts ...Option) (*PKI, error) {
currentCtx := step.Contexts().GetCurrent()
caService, err := cas.New(context.Background(), o) caService, err := cas.New(context.Background(), o)
if err != nil { if err != nil {
return nil, err return nil, err
@ -358,6 +366,9 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
cfg = GetConfigPath() cfg = GetConfigPath()
// Create directories // Create directories
dirs := []string{public, private, cfg, GetTemplatesPath()} dirs := []string{public, private, cfg, GetTemplatesPath()}
if currentCtx != nil {
dirs = append(dirs, GetProfileConfigPath())
}
for _, name := range dirs { for _, name := range dirs {
if _, err := os.Stat(name); os.IsNotExist(err) { if _, err := os.Stat(name); os.IsNotExist(err) {
if err = os.MkdirAll(name, 0700); err != nil { if err = os.MkdirAll(name, 0700); err != nil {
@ -415,6 +426,10 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
if p.defaults, err = getPath(cfg, "defaults.json"); err != nil { if p.defaults, err = getPath(cfg, "defaults.json"); err != nil {
return nil, err return nil, err
} }
if currentCtx != nil {
p.profileDefaults = currentCtx.ProfileDefaultsFile()
}
if p.config, err = getPath(cfg, "ca.json"); err != nil { if p.config, err = getPath(cfg, "ca.json"); err != nil {
return nil, err return nil, err
} }
@ -944,6 +959,18 @@ func (p *PKI) Save(opt ...ConfigOption) error {
if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil { if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil {
return errs.FileError(err, p.defaults) return errs.FileError(err, p.defaults)
} }
// If we're using contexts then write a blank object to the default profile
// configuration location.
if p.profileDefaults != "" {
if _, err := os.Stat(p.profileDefaults); os.IsNotExist(err) {
// Write with 0600 to be consistent with directories structure.
if err = fileutil.WriteFile(p.profileDefaults, []byte("{}"), 0600); err != nil {
return errs.FileError(err, p.profileDefaults)
}
} else if err != nil {
return errs.FileError(err, p.profileDefaults)
}
}
// Generate and write templates // Generate and write templates
if err := generateTemplates(cfg.Templates); err != nil { if err := generateTemplates(cfg.Templates); err != nil {
@ -958,6 +985,9 @@ func (p *PKI) Save(opt ...ConfigOption) error {
} }
ui.PrintSelected("Default configuration", p.defaults) ui.PrintSelected("Default configuration", p.defaults)
if p.profileDefaults != "" {
ui.PrintSelected("Default profile configuration", p.profileDefaults)
}
ui.PrintSelected("Certificate Authority configuration", p.config) ui.PrintSelected("Certificate Authority configuration", p.config)
if p.options.deploymentType != LinkedDeployment { if p.options.deploymentType != LinkedDeployment {
ui.Println() ui.Println()

View file

@ -6,9 +6,9 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// getTemplates returns all the templates enabled // getTemplates returns all the templates enabled
@ -44,7 +44,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }
@ -53,7 +53,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }

View file

@ -5,7 +5,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -167,7 +166,7 @@ func decodeSCEPRequest(r *http.Request) (SCEPRequest, error) {
return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation) return SCEPRequest{}, errors.Errorf("unsupported operation: %s", operation)
} }
case http.MethodPost: case http.MethodPost:
body, err := ioutil.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize))
if err != nil { if err != nil {
return SCEPRequest{}, err return SCEPRequest{}, err
} }

View file

@ -2,7 +2,6 @@ package templates
import ( import (
"bytes" "bytes"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -10,8 +9,8 @@ import (
"github.com/Masterminds/sprig/v3" "github.com/Masterminds/sprig/v3"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// TemplateType defines how a template will be written in disk. // TemplateType defines how a template will be written in disk.
@ -20,6 +19,9 @@ type TemplateType string
const ( const (
// Snippet will mark a template as a part of a file. // Snippet will mark a template as a part of a file.
Snippet TemplateType = "snippet" Snippet TemplateType = "snippet"
// PrependLine is a template for prepending a single line to a file. If the
// line already exists in the file it will be removed first.
PrependLine TemplateType = "prepend-line"
// File will mark a templates as a full file. // File will mark a templates as a full file.
File TemplateType = "file" File TemplateType = "file"
// Directory will mark a template as a directory. // Directory will mark a template as a directory.
@ -99,7 +101,7 @@ func (t *SSHTemplates) Validate() (err error) {
return return
} }
// Template represents on template file. // Template represents a template file.
type Template struct { type Template struct {
*template.Template *template.Template
Name string `json:"name"` Name string `json:"name"`
@ -118,8 +120,8 @@ func (t *Template) Validate() error {
return nil return nil
case t.Name == "": case t.Name == "":
return errors.New("template name cannot be empty") return errors.New("template name cannot be empty")
case t.Type != Snippet && t.Type != File && t.Type != Directory: case t.Type != Snippet && t.Type != File && t.Type != Directory && t.Type != PrependLine:
return errors.Errorf("invalid template type %s, it must be %s, %s, or %s", t.Type, Snippet, File, Directory) return errors.Errorf("invalid template type %s, it must be %s, %s, %s, or %s", t.Type, Snippet, PrependLine, File, Directory)
case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0: case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0:
return errors.New("template template cannot be empty") return errors.New("template template cannot be empty")
case t.TemplatePath != "" && t.Type == Directory: case t.TemplatePath != "" && t.Type == Directory:
@ -132,7 +134,7 @@ func (t *Template) Validate() error {
if t.TemplatePath != "" { if t.TemplatePath != "" {
// Check for file // Check for file
st, err := os.Stat(config.StepAbs(t.TemplatePath)) st, err := os.Stat(step.Abs(t.TemplatePath))
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", t.TemplatePath) return errors.Wrapf(err, "error reading %s", t.TemplatePath)
} }
@ -166,8 +168,8 @@ func (t *Template) Load() error {
if t.Template == nil && t.Type != Directory { if t.Template == nil && t.Type != Directory {
switch { switch {
case t.TemplatePath != "": case t.TemplatePath != "":
filename := config.StepAbs(t.TemplatePath) filename := step.Abs(t.TemplatePath)
b, err := ioutil.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", filename) return errors.Wrapf(err, "error reading %s", filename)
} }
@ -247,7 +249,10 @@ type Output struct {
// Write writes the Output to the filesystem as a directory, file or snippet. // Write writes the Output to the filesystem as a directory, file or snippet.
func (o *Output) Write() error { func (o *Output) Write() error {
path := config.StepAbs(o.Path) // Replace ${STEPPATH} with the base step path.
o.Path = strings.ReplaceAll(o.Path, "${STEPPATH}", step.BasePath())
path := step.Abs(o.Path)
if o.Type == Directory { if o.Type == Directory {
return mkdir(path, 0700) return mkdir(path, 0700)
} }
@ -257,11 +262,17 @@ func (o *Output) Write() error {
return err return err
} }
if o.Type == File { switch o.Type {
case File:
return fileutil.WriteFile(path, o.Content, 0600) return fileutil.WriteFile(path, o.Content, 0600)
case Snippet:
return fileutil.WriteSnippet(path, o.Content, 0600)
case PrependLine:
return fileutil.PrependLine(path, o.Content, 0600)
default:
// Default to using a Snippet type if the type is not known.
return fileutil.WriteSnippet(path, o.Content, 0600)
} }
return fileutil.WriteSnippet(path, o.Content, 0600)
} }
func mkdir(path string, perm os.FileMode) error { func mkdir(path string, perm os.FileMode) error {

View file

@ -4,6 +4,10 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// SSHTemplateVersionKey is a key that can be submitted by a client to select
// the template version that will be returned by the server.
var SSHTemplateVersionKey = "StepSSHTemplateVersion"
// Step represents the default variables available in the CA. // Step represents the default variables available in the CA.
type Step struct { type Step struct {
SSH StepSSH SSH StepSSH
@ -22,16 +26,23 @@ type StepSSH struct {
var DefaultSSHTemplates = SSHTemplates{ var DefaultSSHTemplates = SSHTemplates{
User: []Template{ User: []Template{
{ {
Name: "include.tpl", Name: "config.tpl",
Type: Snippet, Type: Snippet,
TemplatePath: "templates/ssh/include.tpl", TemplatePath: "templates/ssh/config.tpl",
Path: "~/.ssh/config", Path: "~/.ssh/config",
Comment: "#", Comment: "#",
}, },
{ {
Name: "config.tpl", Name: "step_includes.tpl",
Type: PrependLine,
TemplatePath: "templates/ssh/step_includes.tpl",
Path: "${STEPPATH}/ssh/includes",
Comment: "#",
},
{
Name: "step_config.tpl",
Type: File, Type: File,
TemplatePath: "templates/ssh/config.tpl", TemplatePath: "templates/ssh/step_config.tpl",
Path: "ssh/config", Path: "ssh/config",
Comment: "#", Comment: "#",
}, },
@ -64,30 +75,43 @@ var DefaultSSHTemplates = SSHTemplates{
// DefaultSSHTemplateData contains the data of the default templates used on ssh. // DefaultSSHTemplateData contains the data of the default templates used on ssh.
var DefaultSSHTemplateData = map[string]string{ var DefaultSSHTemplateData = map[string]string{
// include.tpl adds the step ssh config file. // config.tpl adds the step ssh config file.
// //
// Note: on windows `Include C:\...` is treated as a relative path. // Note: on windows `Include C:\...` is treated as a relative path.
"include.tpl": `Host * "config.tpl": `Host *
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config" {{- if .User.StepBasePath }}
Include "{{ .User.StepBasePath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- else }} {{- else }}
Include "{{.User.StepPath}}/ssh/config" Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- end }}
{{- else }}
{{- if .User.StepBasePath }}
Include "{{.User.StepBasePath}}/ssh/includes"
{{- else }}
Include "{{.User.StepPath}}/ssh/includes"
{{- end }}
{{- end }}`, {{- end }}`,
// config.tpl is the step ssh config file, it includes the Match rule and // step_includes.tpl adds the step ssh config file.
//
// Note: on windows `Include C:\...` is treated as a relative path.
"step_includes.tpl": `{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}`,
// step_config.tpl is the step ssh config file, it includes the Match rule and
// references the step known_hosts file. // references the step known_hosts file.
// //
// Note: on windows ProxyCommand requires the full path // Note: on windows ProxyCommand requires the full path
"config.tpl": `Match exec "step ssh check-host %h" "step_config.tpl": `Match exec "step ssh check-host{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %h"
{{- if .User.User }} {{- if .User.User }}
User {{.User.User}} User {{.User.User}}
{{- end }} {{- end }}
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts" UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts"
ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand %r %h %p ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- else }} {{- else }}
UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts" UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts"
ProxyCommand step ssh proxycommand %r %h %p ProxyCommand step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- end }} {{- end }}
`, `,