Merge pull request #847 from smallstep/herman/allow-deny-next
Refactor allow/deny (WIP)
This commit is contained in:
commit
c45d177d52
112 changed files with 4101 additions and 1274 deletions
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
|
@ -12,7 +12,7 @@ jobs:
|
|||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
matrix:
|
||||
go: [ '1.15', '1.16', '1.17' ]
|
||||
go: [ '1.17', '1.18' ]
|
||||
outputs:
|
||||
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
|
||||
steps:
|
||||
|
@ -33,7 +33,7 @@ jobs:
|
|||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||
version: 'v1.44.0'
|
||||
version: 'v1.45.0'
|
||||
|
||||
# Optional: working directory, useful for monorepos
|
||||
# working-directory: somedir
|
||||
|
@ -106,7 +106,7 @@ jobs:
|
|||
name: Set up Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: 1.17
|
||||
go-version: 1.18
|
||||
-
|
||||
name: APT Install
|
||||
id: aptInstall
|
||||
|
@ -159,7 +159,7 @@ jobs:
|
|||
name: Setup Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: '1.17'
|
||||
go-version: '1.18'
|
||||
-
|
||||
name: Install cosign
|
||||
uses: sigstore/cosign-installer@v1.1.0
|
||||
|
|
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
|
@ -14,7 +14,7 @@ jobs:
|
|||
runs-on: ubuntu-20.04
|
||||
strategy:
|
||||
matrix:
|
||||
go: [ '1.16', '1.17' ]
|
||||
go: [ '1.17', '1.18' ]
|
||||
steps:
|
||||
-
|
||||
name: Checkout
|
||||
|
@ -33,7 +33,7 @@ jobs:
|
|||
uses: golangci/golangci-lint-action@v2
|
||||
with:
|
||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||
version: 'v1.44.0'
|
||||
version: 'v1.45.0'
|
||||
|
||||
# Optional: working directory, useful for monorepos
|
||||
# working-directory: somedir
|
||||
|
@ -58,7 +58,7 @@ jobs:
|
|||
run: V=1 make ci
|
||||
-
|
||||
name: Codecov
|
||||
if: matrix.go == '1.17'
|
||||
if: matrix.go == '1.18'
|
||||
uses: codecov/codecov-action@v1.2.1
|
||||
with:
|
||||
file: ./coverage.out # optional
|
||||
|
|
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -19,3 +19,4 @@ coverage.txt
|
|||
output
|
||||
vendor
|
||||
.idea
|
||||
.envrc
|
||||
|
|
|
@ -19,6 +19,7 @@ builds:
|
|||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_5
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
|
@ -39,6 +40,7 @@ builds:
|
|||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_5
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
|
@ -59,6 +61,7 @@ builds:
|
|||
- linux_386
|
||||
- linux_amd64
|
||||
- linux_arm64
|
||||
- linux_arm_5
|
||||
- linux_arm_6
|
||||
- linux_arm_7
|
||||
- windows_amd64
|
||||
|
|
|
@ -6,7 +6,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||
|
||||
## [Unreleased - 0.18.3] - DATE
|
||||
### Added
|
||||
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
|
||||
### Changed
|
||||
- Made SCEP CA URL paths dynamic
|
||||
- Support two latest versions of Go (1.17, 1.18)
|
||||
### Deprecated
|
||||
### Removed
|
||||
### Fixed
|
||||
|
@ -15,6 +18,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
|
|||
## [0.18.2] - 2022-03-01
|
||||
### Added
|
||||
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
|
||||
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
|
||||
out database drivers using Go tags. For example, using the Go flag
|
||||
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
|
||||
driver for PostgreSQL.
|
||||
### Changed
|
||||
- IPv6 addresses are normalized as IP addresses instead of hostnames.
|
||||
- More descriptive JWK decryption error message.
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/crypto/randutil"
|
||||
)
|
||||
|
||||
|
@ -104,10 +105,13 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
|
|||
// management of allowed/denied names based on just the name, without having bound to EAB. Still,
|
||||
// EAB is not illogical, because that's the way Accounts are connected to an external system and
|
||||
// thus make sense to also set the allowed/denied names based on that info.
|
||||
// TODO: also perform check on the authority level here already, so that challenges are not performed
|
||||
// and after that the CA fails to sign it. (i.e. h.ca function?)
|
||||
|
||||
for _, identifier := range nor.Identifiers {
|
||||
// TODO: gather all errors, so that we can build subproblems; include the nor.Validate() error here too, like in example?
|
||||
err = prov.AuthorizeOrderIdentifier(ctx, identifier.Value)
|
||||
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
|
||||
err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier)
|
||||
if err != nil {
|
||||
api.WriteError(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
|
||||
return
|
||||
|
|
|
@ -30,7 +30,7 @@ var clock Clock
|
|||
// Provisioner is an interface that implements a subset of the provisioner.Interface --
|
||||
// only those methods required by the ACME api/authority.
|
||||
type Provisioner interface {
|
||||
AuthorizeOrderIdentifier(ctx context.Context, identifier string) error
|
||||
AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
|
||||
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRevoke(ctx context.Context, token string) error
|
||||
GetID() string
|
||||
|
@ -45,7 +45,7 @@ type MockProvisioner struct {
|
|||
Merr error
|
||||
MgetID func() string
|
||||
MgetName func() string
|
||||
MauthorizeOrderIdentifier func(ctx context.Context, identifier string) error
|
||||
MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
|
||||
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
MauthorizeRevoke func(ctx context.Context, token string) error
|
||||
MdefaultTLSCertDuration func() time.Duration
|
||||
|
@ -61,7 +61,7 @@ func (m *MockProvisioner) GetName() string {
|
|||
}
|
||||
|
||||
// AuthorizeOrderIdentifiers mock
|
||||
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error {
|
||||
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
|
||||
if m.MauthorizeOrderIdentifier != nil {
|
||||
return m.MauthorizeOrderIdentifier(ctx, identifier)
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
|
@ -33,6 +34,7 @@ type Authority interface {
|
|||
// context specifies the Authorize[Sign|Revoke|etc.] method.
|
||||
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
|
||||
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
GetTLSOptions() *config.TLSOptions
|
||||
Root(shasum string) (*x509.Certificate, error)
|
||||
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
|
@ -43,7 +45,7 @@ type Authority interface {
|
|||
GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
|
||||
Revoke(context.Context, *authority.RevokeOptions) error
|
||||
GetEncryptedKey(kid string) (string, error)
|
||||
GetRoots() (federation []*x509.Certificate, err error)
|
||||
GetRoots() ([]*x509.Certificate, error)
|
||||
GetFederation() ([]*x509.Certificate, error)
|
||||
Version() authority.Version
|
||||
}
|
||||
|
|
133
api/api_test.go
133
api/api_test.go
|
@ -13,6 +13,7 @@ import (
|
|||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
@ -27,14 +28,17 @@ import (
|
|||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"go.step.sm/crypto/jose"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -171,6 +175,7 @@ type mockAuthority struct {
|
|||
ret1, ret2 interface{}
|
||||
err error
|
||||
authorizeSign func(ott string) ([]provisioner.SignOption, error)
|
||||
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
|
||||
getTLSOptions func() *authority.TLSOptions
|
||||
root func(shasum string) (*x509.Certificate, error)
|
||||
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
|
||||
|
@ -208,6 +213,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err
|
|||
return m.ret1.([]provisioner.SignOption), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||
if m.authorizeRenewToken != nil {
|
||||
return m.authorizeRenewToken(ctx, ott)
|
||||
}
|
||||
return m.ret1.(*x509.Certificate), m.err
|
||||
}
|
||||
|
||||
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
|
||||
if m.getTLSOptions != nil {
|
||||
return m.getTLSOptions()
|
||||
|
@ -920,48 +932,141 @@ func Test_caHandler_Renew(t *testing.T) {
|
|||
cs := &tls.ConnectionState{
|
||||
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
|
||||
}
|
||||
|
||||
// Prepare root and leaf for renew after expiry test.
|
||||
now := time.Now()
|
||||
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
root := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "Test Root CA"},
|
||||
PublicKey: rootPub,
|
||||
KeyUsage: x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
NotBefore: now.Add(-2 * time.Hour),
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}
|
||||
root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expiredLeaf := &x509.Certificate{
|
||||
Subject: pkix.Name{CommonName: "Leaf certificate"},
|
||||
PublicKey: leafPub,
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
EmailAddresses: []string{"test@example.org"},
|
||||
}
|
||||
expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Generate renew after expiry token
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithType("JWT")
|
||||
so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)})
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
generateX5cToken := func(claims jose.Claims) string {
|
||||
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tls *tls.ConnectionState
|
||||
header http.Header
|
||||
cert *x509.Certificate
|
||||
root *x509.Certificate
|
||||
err error
|
||||
statusCode int
|
||||
}{
|
||||
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||
{"no tls", nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest},
|
||||
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
||||
{"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
|
||||
{"ok renew after expiry", &tls.ConnectionState{}, http.Header{
|
||||
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
|
||||
NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
})},
|
||||
}, expiredLeaf, root, nil, http.StatusCreated},
|
||||
{"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest},
|
||||
{"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
|
||||
{"fail expired token", &tls.ConnectionState{}, http.Header{
|
||||
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
|
||||
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
|
||||
})},
|
||||
}, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized},
|
||||
{"fail invalid root", &tls.ConnectionState{}, http.Header{
|
||||
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
|
||||
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
|
||||
})},
|
||||
}, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized},
|
||||
}
|
||||
|
||||
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := New(&mockAuthority{
|
||||
ret1: tt.cert, ret2: tt.root, err: tt.err,
|
||||
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
|
||||
if err != nil {
|
||||
return nil, errs.Unauthorized(err.Error())
|
||||
}
|
||||
var claims jose.Claims
|
||||
if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil {
|
||||
return nil, errs.Unauthorized(err.Error())
|
||||
}
|
||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||
Time: now,
|
||||
}, time.Minute); err != nil {
|
||||
return nil, errs.Unauthorized(err.Error())
|
||||
}
|
||||
return chain[0][0], nil
|
||||
},
|
||||
getTLSOptions: func() *authority.TLSOptions {
|
||||
return nil
|
||||
},
|
||||
}).(*caHandler)
|
||||
req := httptest.NewRequest("POST", "http://example.com/renew", nil)
|
||||
req.TLS = tt.tls
|
||||
req.Header = tt.header
|
||||
w := httptest.NewRecorder()
|
||||
h.Renew(logging.NewResponseLogger(w), req)
|
||||
res := w.Result()
|
||||
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
}
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
t.Errorf("caHandler.Renew unexpected error = %v", err)
|
||||
}
|
||||
if res.StatusCode != tt.statusCode {
|
||||
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
|
||||
t.Errorf("%s", body)
|
||||
}
|
||||
|
||||
if tt.statusCode < http.StatusBadRequest {
|
||||
expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` +
|
||||
`"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` +
|
||||
`"certChain":["` +
|
||||
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` +
|
||||
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`)
|
||||
|
||||
if !bytes.Equal(bytes.TrimSpace(body), expected) {
|
||||
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected)
|
||||
t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -7,7 +7,9 @@ import (
|
|||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/certificates/acme"
|
||||
"github.com/smallstep/certificates/api/log"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
|
@ -59,6 +61,6 @@ func WriteError(w http.ResponseWriter, err error) {
|
|||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(err); err != nil {
|
||||
LogError(w, err)
|
||||
log.Error(w, err)
|
||||
}
|
||||
}
|
||||
|
|
47
api/log/log.go
Normal file
47
api/log/log.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
// Package log implements API-related logging helpers.
|
||||
package log
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
// Error adds to the response writer the given error if it implements
|
||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||
// using the log package.
|
||||
func Error(rw http.ResponseWriter, err error) {
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err,
|
||||
})
|
||||
} else {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
// EnabledResponse log the response object if it implements the EnableLogger
|
||||
// interface.
|
||||
func EnabledResponse(rw http.ResponseWriter, v interface{}) {
|
||||
type enableLogger interface {
|
||||
ToLog() (interface{}, error)
|
||||
}
|
||||
|
||||
if el, ok := v.(enableLogger); ok {
|
||||
out, err := el.ToLog()
|
||||
if err != nil {
|
||||
Error(rw, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"response": out,
|
||||
})
|
||||
} else {
|
||||
log.Println(out)
|
||||
}
|
||||
}
|
||||
}
|
44
api/log/log_test.go
Normal file
44
api/log/log_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package log
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
func TestError(t *testing.T) {
|
||||
theError := errors.New("the error")
|
||||
|
||||
type args struct {
|
||||
rw http.ResponseWriter
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
withFields bool
|
||||
}{
|
||||
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
|
||||
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
Error(tt.args.rw, tt.args.err)
|
||||
if tt.withFields {
|
||||
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
|
||||
fields := rl.Fields()
|
||||
if !reflect.DeepEqual(fields["error"], theError) {
|
||||
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
|
||||
}
|
||||
} else {
|
||||
t.Error("ResponseWriter does not implement logging.ResponseLogger")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
31
api/read/read.go
Normal file
31
api/read/read.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
// Package read implements request object readers.
|
||||
package read
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// JSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func JSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return errs.BadRequestErr(err, "error decoding json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProtoJSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func ProtoJSON(r io.Reader, m proto.Message) error {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return errs.BadRequestErr(err, "error reading request body")
|
||||
}
|
||||
return protojson.Unmarshal(data, m)
|
||||
}
|
46
api/read/read_test.go
Normal file
46
api/read/read_test.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package read
|
||||
|
||||
import (
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
type args struct {
|
||||
r io.Reader
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
||||
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := JSON(tt.args.r, &tt.args.v)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
e, ok := err.(*errs.Error)
|
||||
if ok {
|
||||
if code := e.StatusCode(); code != 400 {
|
||||
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("error type = %T, wants *Error", err)
|
||||
}
|
||||
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
||||
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
|
@ -32,7 +33,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
var body RekeyRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
|
26
api/renew.go
26
api/renew.go
|
@ -1,20 +1,28 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
bearerScheme = "Bearer"
|
||||
)
|
||||
|
||||
// Renew uses the information of certificate in the TLS connection to create a
|
||||
// new one.
|
||||
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
|
||||
WriteError(w, errs.BadRequest("missing client certificate"))
|
||||
cert, err := h.getPeerCertificate(r)
|
||||
if err != nil {
|
||||
WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0])
|
||||
certChain, err := h.Authority.Renew(cert)
|
||||
if err != nil {
|
||||
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
|
||||
return
|
||||
|
@ -33,3 +41,15 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
|
|||
TLSOptions: h.Authority.GetTLSOptions(),
|
||||
}, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
return r.TLS.PeerCertificates[0], nil
|
||||
}
|
||||
if s := r.Header.Get(authorizationHeader); s != "" {
|
||||
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
|
||||
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
|
||||
}
|
||||
}
|
||||
return nil, errs.BadRequest("missing client certificate")
|
||||
}
|
||||
|
|
|
@ -4,11 +4,13 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
// RevokeResponse is the response object that returns the health of the server.
|
||||
|
@ -48,7 +50,7 @@ func (r *RevokeRequest) Validate() (err error) {
|
|||
// TODO: Add CRL and OCSP support.
|
||||
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body RevokeRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
@ -49,7 +50,7 @@ type SignResponse struct {
|
|||
// information in the certificate request.
|
||||
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
|
12
api/ssh.go
12
api/ssh.go
|
@ -9,12 +9,14 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SSHAuthority is the interface implemented by a SSH CA authority.
|
||||
|
@ -249,7 +251,7 @@ type SSHBastionResponse struct {
|
|||
// the request.
|
||||
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHSignRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
@ -393,7 +395,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
|
|||
// and servers.
|
||||
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHConfigRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
@ -425,7 +427,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
|
|||
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
|
||||
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHCheckPrincipalRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
@ -464,7 +466,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
|
|||
// SSHBastion provides returns the bastion configured if any.
|
||||
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHBastionRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,9 +4,11 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SSHRekeyRequest is the request body of an SSH certificate request.
|
||||
|
@ -38,7 +40,7 @@ type SSHRekeyResponse struct {
|
|||
// the request.
|
||||
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRekeyRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
@ -36,7 +38,7 @@ type SSHRenewResponse struct {
|
|||
// the request.
|
||||
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRenewRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -3,11 +3,13 @@ package api
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
// SSHRevokeResponse is the response object that returns the health of the server.
|
||||
|
@ -47,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
|
|||
// NOTE: currently only Passive revocation is supported.
|
||||
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
|
||||
var body SSHRevokeRequest
|
||||
if err := ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
WriteError(w, errs.BadRequestErr(err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -18,12 +18,13 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"github.com/smallstep/certificates/templates"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
117
api/utils.go
117
api/utils.go
|
@ -2,14 +2,15 @@ package api
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
|
||||
"net/http"
|
||||
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/smallstep/certificates/api/log"
|
||||
)
|
||||
|
||||
// EnableLogger is an interface that enables response logging for an object.
|
||||
|
@ -17,38 +18,6 @@ type EnableLogger interface {
|
|||
ToLog() (interface{}, error)
|
||||
}
|
||||
|
||||
// LogError adds to the response writer the given error if it implements
|
||||
// logging.ResponseLogger. If it does not implement it, then writes the error
|
||||
// using the log package.
|
||||
func LogError(rw http.ResponseWriter, err error) {
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"error": err,
|
||||
})
|
||||
} else {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
||||
|
||||
// LogEnabledResponse log the response object if it implements the EnableLogger
|
||||
// interface.
|
||||
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
|
||||
if el, ok := v.(EnableLogger); ok {
|
||||
out, err := el.ToLog()
|
||||
if err != nil {
|
||||
LogError(rw, err)
|
||||
return
|
||||
}
|
||||
if rl, ok := rw.(logging.ResponseLogger); ok {
|
||||
rl.WithFields(map[string]interface{}{
|
||||
"response": out,
|
||||
})
|
||||
} else {
|
||||
log.Println(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// JSON writes the passed value into the http.ResponseWriter.
|
||||
func JSON(w http.ResponseWriter, v interface{}) {
|
||||
JSONStatus(w, v, http.StatusOK)
|
||||
|
@ -60,10 +29,19 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
|
|||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
if err := json.NewEncoder(w).Encode(v); err != nil {
|
||||
LogError(w, err)
|
||||
log.Error(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
LogEnabledResponse(w, v)
|
||||
|
||||
log.EnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// JSONNotFound writes a HTTP Not Found response with empty body.
|
||||
func JSONNotFound(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
log.EnabledResponse(w, nil)
|
||||
}
|
||||
|
||||
// ProtoJSON writes the passed value into the http.ResponseWriter.
|
||||
|
@ -79,31 +57,62 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
|
|||
|
||||
b, err := protojson.Marshal(m)
|
||||
if err != nil {
|
||||
LogError(w, err)
|
||||
log.Error(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := w.Write(b); err != nil {
|
||||
LogError(w, err)
|
||||
log.Error(w, err)
|
||||
|
||||
return
|
||||
}
|
||||
//LogEnabledResponse(w, v)
|
||||
|
||||
// log.EnabledResponse(w, v)
|
||||
}
|
||||
|
||||
// ReadJSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func ReadJSON(r io.Reader, v interface{}) error {
|
||||
if err := json.NewDecoder(r).Decode(v); err != nil {
|
||||
return errs.BadRequestErr(err, "error decoding json")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadProtoJSON reads JSON from the request body and stores it in the value
|
||||
// pointed by v.
|
||||
func ReadProtoJSON(r io.Reader, m proto.Message) error {
|
||||
// ReadProtoJSONWithCheck reads JSON from the request body and stores it in the value
|
||||
// pointed by v. TODO(hs): move this to and integrate with render package.
|
||||
func ReadProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool {
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return errs.BadRequestErr(err, "error reading request body")
|
||||
var wrapper = struct {
|
||||
Status int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Status: http.StatusBadRequest,
|
||||
Message: err.Error(),
|
||||
}
|
||||
return protojson.Unmarshal(data, m)
|
||||
errData, err := json.Marshal(wrapper)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write(errData)
|
||||
return false
|
||||
}
|
||||
if err := protojson.Unmarshal(data, m); err != nil {
|
||||
if errors.Is(err, proto.Error) {
|
||||
var wrapper = struct {
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
Message: err.Error(),
|
||||
}
|
||||
errData, err := json.Marshal(wrapper)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write(errData)
|
||||
return false
|
||||
}
|
||||
|
||||
// fallback to the default error writer
|
||||
WriteError(w, err)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -1,50 +1,13 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/logging"
|
||||
)
|
||||
|
||||
func TestLogError(t *testing.T) {
|
||||
theError := errors.New("the error")
|
||||
type args struct {
|
||||
rw http.ResponseWriter
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
withFields bool
|
||||
}{
|
||||
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
|
||||
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
LogError(tt.args.rw, tt.args.err)
|
||||
if tt.withFields {
|
||||
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
|
||||
fields := rl.Fields()
|
||||
if !reflect.DeepEqual(fields["error"], theError) {
|
||||
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
|
||||
}
|
||||
} else {
|
||||
t.Error("ResponseWriter does not implement logging.ResponseLogger")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSON(t *testing.T) {
|
||||
type args struct {
|
||||
rw http.ResponseWriter
|
||||
|
@ -88,38 +51,3 @@ func TestJSON(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadJSON(t *testing.T) {
|
||||
type args struct {
|
||||
r io.Reader
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
|
||||
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ReadJSON(tt.args.r, &tt.args.v)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
e, ok := err.(*errs.Error)
|
||||
if ok {
|
||||
if code := e.StatusCode(); code != 400 {
|
||||
t.Errorf("error.StatusCode() = %v, wants 400", code)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("error type = %T, wants *Error", err)
|
||||
}
|
||||
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
|
||||
t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,7 +14,7 @@ import (
|
|||
|
||||
const (
|
||||
// provisionerContextKey provisioner key
|
||||
provisionerContextKey = ContextKey("provisioner")
|
||||
provisionerContextKey = admin.ContextKey("provisioner")
|
||||
)
|
||||
|
||||
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
|
||||
|
|
|
@ -5,10 +5,13 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
type adminAuthority interface {
|
||||
|
@ -25,6 +28,10 @@ type adminAuthority interface {
|
|||
LoadProvisionerByID(id string) (provisioner.Interface, error)
|
||||
UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
RemoveProvisioner(ctx context.Context, id string) error
|
||||
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
|
||||
CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
RemoveAuthorityPolicy(ctx context.Context) error
|
||||
}
|
||||
|
||||
// CreateAdminRequest represents the body for a CreateAdmin request.
|
||||
|
@ -112,7 +119,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
|
|||
// CreateAdmin creates a new admin.
|
||||
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body CreateAdminRequest
|
||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
@ -156,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
|
|||
// UpdateAdmin updates an existing admin.
|
||||
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
|
||||
var body UpdateAdminRequest
|
||||
if err := api.ReadJSON(r.Body, &body); err != nil {
|
||||
if err := read.JSON(r.Body, &body); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -14,11 +14,13 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"go.step.sm/linkedca"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type mockAdminAuthority struct {
|
||||
|
@ -37,6 +39,11 @@ type mockAdminAuthority struct {
|
|||
MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
|
||||
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
|
||||
MockRemoveProvisioner func(ctx context.Context, id string) error
|
||||
|
||||
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
|
||||
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) (*linkedca.Policy, error)
|
||||
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
|
||||
MockRemoveAuthorityPolicy func(ctx context.Context) error
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) IsAdminAPIEnabled() bool {
|
||||
|
@ -130,6 +137,22 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e
|
|||
return m.MockErr
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
return nil, errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
return nil, errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error {
|
||||
return errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func TestCreateAdminRequest_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
Subject string
|
||||
|
|
|
@ -12,20 +12,23 @@ type Handler struct {
|
|||
auth adminAuthority
|
||||
acmeDB acme.DB
|
||||
acmeResponder acmeAdminResponderInterface
|
||||
policyResponder policyAdminResponderInterface
|
||||
}
|
||||
|
||||
// NewHandler returns a new Authority Config Handler.
|
||||
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler {
|
||||
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler {
|
||||
return &Handler{
|
||||
auth: auth,
|
||||
adminDB: adminDB,
|
||||
acmeDB: acmeDB,
|
||||
acmeResponder: acmeResponder,
|
||||
policyResponder: policyResponder,
|
||||
}
|
||||
}
|
||||
|
||||
// Route traffic and implement the Router interface.
|
||||
func (h *Handler) Route(r api.Router) {
|
||||
|
||||
authnz := func(next nextHTTP) nextHTTP {
|
||||
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
|
||||
}
|
||||
|
@ -34,6 +37,14 @@ func (h *Handler) Route(r api.Router) {
|
|||
return h.requireEABEnabled(next)
|
||||
}
|
||||
|
||||
enabledInStandalone := func(next nextHTTP) nextHTTP {
|
||||
return h.checkAction(next, true)
|
||||
}
|
||||
|
||||
disabledInStandalone := func(next nextHTTP) nextHTTP {
|
||||
return h.checkAction(next, false)
|
||||
}
|
||||
|
||||
// Provisioners
|
||||
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
|
||||
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
|
||||
|
@ -53,4 +64,24 @@ func (h *Handler) Route(r api.Router) {
|
|||
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
|
||||
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey)))
|
||||
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey)))
|
||||
|
||||
// Policy - Authority
|
||||
r.MethodFunc("GET", "/policy", authnz(enabledInStandalone(h.policyResponder.GetAuthorityPolicy)))
|
||||
r.MethodFunc("POST", "/policy", authnz(enabledInStandalone(h.policyResponder.CreateAuthorityPolicy)))
|
||||
r.MethodFunc("PUT", "/policy", authnz(enabledInStandalone(h.policyResponder.UpdateAuthorityPolicy)))
|
||||
r.MethodFunc("DELETE", "/policy", authnz(enabledInStandalone(h.policyResponder.DeleteAuthorityPolicy)))
|
||||
|
||||
// Policy - Provisioner
|
||||
//r.MethodFunc("GET", "/provisioners/{name}/policy", noauth(h.policyResponder.GetProvisionerPolicy))
|
||||
r.MethodFunc("GET", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.GetProvisionerPolicy)))
|
||||
r.MethodFunc("POST", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.CreateProvisionerPolicy)))
|
||||
r.MethodFunc("PUT", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.UpdateProvisionerPolicy)))
|
||||
r.MethodFunc("DELETE", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.DeleteProvisionerPolicy)))
|
||||
|
||||
// Policy - ACME Account
|
||||
// TODO: ensure we don't clash with eab; might want to change eab paths slightly (as long as we don't have it released completely; needs changes in adminClient too)
|
||||
r.MethodFunc("GET", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.GetACMEAccountPolicy)))
|
||||
r.MethodFunc("POST", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.CreateACMEAccountPolicy)))
|
||||
r.MethodFunc("PUT", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.UpdateACMEAccountPolicy)))
|
||||
r.MethodFunc("DELETE", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.DeleteACMEAccountPolicy)))
|
||||
}
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
)
|
||||
|
||||
type nextHTTP = func(http.ResponseWriter, *http.Request)
|
||||
|
@ -26,6 +28,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
|
|||
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
|
||||
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
tok := r.Header.Get("Authorization")
|
||||
if tok == "" {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
|
||||
|
@ -39,16 +42,30 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
|
|||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), adminContextKey, adm)
|
||||
ctx := linkedca.WithAdmin(r.Context(), adm)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
// ContextKey is the key type for storing and searching for ACME request
|
||||
// essentials in the context of a request.
|
||||
type ContextKey string
|
||||
// checkAction checks if an action is supported in standalone or not
|
||||
func (h *Handler) checkAction(next nextHTTP, supportedInStandalone bool) nextHTTP {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
const (
|
||||
// adminContextKey account key
|
||||
adminContextKey = ContextKey("admin")
|
||||
)
|
||||
// actions allowed in standalone mode are always supported
|
||||
if supportedInStandalone {
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// when not in standalone mode and using a nosql.DB backend,
|
||||
// actions are not supported
|
||||
if _, ok := h.adminDB.(*nosql.DB); ok {
|
||||
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType,
|
||||
"operation not supported in standalone mode"))
|
||||
return
|
||||
}
|
||||
|
||||
// continue to next http handler
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -152,7 +152,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
req.Header["Authorization"] = []string{"token"}
|
||||
createdAt := time.Now()
|
||||
var deletedAt time.Time
|
||||
admin := &linkedca.Admin{
|
||||
adm := &linkedca.Admin{
|
||||
Id: "adminID",
|
||||
AuthorityId: "authorityID",
|
||||
Subject: "admin",
|
||||
|
@ -164,20 +164,15 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
|
|||
auth := &mockAdminAuthority{
|
||||
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
|
||||
assert.Equals(t, "token", token)
|
||||
return admin, nil
|
||||
return adm, nil
|
||||
},
|
||||
}
|
||||
next := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin
|
||||
adm, ok := a.(*linkedca.Admin)
|
||||
if !ok {
|
||||
t.Errorf("expected *linkedca.Admin; got %T", a)
|
||||
return
|
||||
}
|
||||
adm := linkedca.AdminFromContext(ctx) // verifying that the context now has a linkedca.Admin
|
||||
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
|
||||
if !cmp.Equal(admin, adm, opts...) {
|
||||
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...))
|
||||
if !cmp.Equal(adm, adm, opts...) {
|
||||
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...))
|
||||
}
|
||||
w.Write(nil) // mock response with status 200
|
||||
}
|
||||
|
|
326
authority/admin/api/policy.go
Normal file
326
authority/admin/api/policy.go
Normal file
|
@ -0,0 +1,326 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
)
|
||||
|
||||
type policyAdminResponderInterface interface {
|
||||
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request)
|
||||
GetProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request)
|
||||
GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// PolicyAdminResponder is responsible for writing ACME admin responses
|
||||
type PolicyAdminResponder struct {
|
||||
auth adminAuthority
|
||||
adminDB admin.DB
|
||||
}
|
||||
|
||||
// NewACMEAdminResponder returns a new ACMEAdminResponder
|
||||
func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB) *PolicyAdminResponder {
|
||||
return &PolicyAdminResponder{
|
||||
auth: auth,
|
||||
adminDB: adminDB,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthorityPolicy handles the GET /admin/authority/policy request
|
||||
func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
policy, err := par.auth.GetAuthorityPolicy(r.Context())
|
||||
if ae, ok := err.(*admin.Error); ok {
|
||||
if !ae.IsType(admin.ErrorNotFoundType) {
|
||||
api.WriteError(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if policy == nil {
|
||||
api.JSONNotFound(w)
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSONStatus(w, policy, http.StatusOK)
|
||||
}
|
||||
|
||||
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
|
||||
func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
policy, err := par.auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
shouldWriteError := false
|
||||
if ae, ok := err.(*admin.Error); ok {
|
||||
shouldWriteError = !ae.IsType(admin.ErrorNotFoundType)
|
||||
}
|
||||
|
||||
if shouldWriteError {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if policy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "authority already has a policy")
|
||||
adminErr.Status = http.StatusConflict
|
||||
api.WriteError(w, adminErr)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if !api.ReadProtoJSONWithCheck(w, r.Body, newPolicy) {
|
||||
return
|
||||
}
|
||||
|
||||
adm := linkedca.AdminFromContext(ctx)
|
||||
|
||||
var createdPolicy *linkedca.Policy
|
||||
if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
api.JSONStatus(w, createdPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
|
||||
func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
policy, err := par.auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
shouldWriteError := false
|
||||
if ae, ok := err.(*admin.Error); ok {
|
||||
shouldWriteError = !ae.IsType(admin.ErrorNotFoundType)
|
||||
}
|
||||
|
||||
if shouldWriteError {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
if policy == nil {
|
||||
api.JSONNotFound(w)
|
||||
return
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
adm := linkedca.AdminFromContext(ctx)
|
||||
|
||||
var updatedPolicy *linkedca.Policy
|
||||
if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
|
||||
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSONStatus(w, updatedPolicy, http.StatusOK)
|
||||
}
|
||||
|
||||
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
|
||||
func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
policy, err := par.auth.GetAuthorityPolicy(ctx)
|
||||
|
||||
if ae, ok := err.(*admin.Error); ok {
|
||||
if !ae.IsType(admin.ErrorNotFoundType) {
|
||||
api.WriteError(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if policy == nil {
|
||||
api.JSONNotFound(w)
|
||||
return
|
||||
}
|
||||
|
||||
err = par.auth.RemoveAuthorityPolicy(ctx)
|
||||
if err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error deleting authority policy"))
|
||||
return
|
||||
}
|
||||
|
||||
api.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
|
||||
}
|
||||
|
||||
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
|
||||
func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: move getting provisioner to middleware?
|
||||
ctx := r.Context()
|
||||
name := chi.URLParam(r, "name")
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if p, err = par.auth.LoadProvisionerByName(name); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := par.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
policy := prov.GetPolicy()
|
||||
if policy == nil {
|
||||
api.JSONNotFound(w)
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSONStatus(w, policy, http.StatusOK)
|
||||
}
|
||||
|
||||
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
|
||||
func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
name := chi.URLParam(r, "name")
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if p, err = par.auth.LoadProvisionerByName(name); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := par.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
policy := prov.GetPolicy()
|
||||
if policy != nil {
|
||||
adminErr := admin.NewError(admin.ErrorBadRequestType, "provisioner %s already has a policy", name)
|
||||
adminErr.Status = http.StatusConflict
|
||||
api.WriteError(w, adminErr)
|
||||
}
|
||||
|
||||
var newPolicy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov.Policy = newPolicy
|
||||
|
||||
err = par.auth.UpdateProvisioner(ctx, prov)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
|
||||
}
|
||||
|
||||
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
|
||||
func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
name := chi.URLParam(r, "name")
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if p, err = par.auth.LoadProvisionerByName(name); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := par.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var policy = new(linkedca.Policy)
|
||||
if err := read.ProtoJSON(r.Body, policy); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
prov.Policy = policy
|
||||
err = par.auth.UpdateProvisioner(ctx, prov)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
api.ProtoJSONStatus(w, policy, http.StatusOK)
|
||||
}
|
||||
|
||||
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
|
||||
func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
ctx := r.Context()
|
||||
name := chi.URLParam(r, "name")
|
||||
var (
|
||||
p provisioner.Interface
|
||||
err error
|
||||
)
|
||||
if p, err = par.auth.LoadProvisionerByName(name); err != nil {
|
||||
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
|
||||
return
|
||||
}
|
||||
|
||||
prov, err := par.adminDB.GetProvisioner(ctx, p.GetID())
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if prov.Policy == nil {
|
||||
api.JSONNotFound(w)
|
||||
return
|
||||
}
|
||||
|
||||
// remove the policy
|
||||
prov.Policy = nil
|
||||
|
||||
err = par.auth.UpdateProvisioner(ctx, prov)
|
||||
if err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
api.JSON(w, &DeleteResponse{Status: "ok"})
|
||||
}
|
||||
|
||||
func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
api.JSON(w, "ok")
|
||||
}
|
||||
|
||||
func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
api.JSON(w, "ok")
|
||||
}
|
||||
|
||||
func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
api.JSON(w, "ok")
|
||||
}
|
||||
|
||||
func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
api.JSON(w, "ok")
|
||||
}
|
|
@ -4,12 +4,14 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
// GetProvisionersResponse is the type for GET /admin/provisioners responses.
|
||||
|
@ -72,7 +74,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
|
|||
// CreateProvisioner creates a new prov.
|
||||
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var prov = new(linkedca.Provisioner)
|
||||
if err := api.ReadProtoJSON(r.Body, prov); err != nil {
|
||||
if err := read.ProtoJSON(r.Body, prov); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
@ -122,7 +124,7 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
|
|||
// UpdateProvisioner updates an existing prov.
|
||||
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
|
||||
var nu = new(linkedca.Provisioner)
|
||||
if err := api.ReadProtoJSON(r.Body, nu); err != nil {
|
||||
if err := read.ProtoJSON(r.Body, nu); err != nil {
|
||||
api.WriteError(w, err)
|
||||
return
|
||||
}
|
||||
|
|
10
authority/admin/context.go
Normal file
10
authority/admin/context.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package admin
|
||||
|
||||
// ContextKey is the key type for storing and searching for
|
||||
// Admin API objects in request contexts.
|
||||
type ContextKey string
|
||||
|
||||
const (
|
||||
// AdminContextKey account key
|
||||
AdminContextKey = ContextKey("admin")
|
||||
)
|
|
@ -69,6 +69,11 @@ type DB interface {
|
|||
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
|
||||
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
|
||||
DeleteAdmin(ctx context.Context, id string) error
|
||||
|
||||
CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
|
||||
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
|
||||
UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
|
||||
DeleteAuthorityPolicy(ctx context.Context) error
|
||||
}
|
||||
|
||||
// MockDB is an implementation of the DB interface that should only be used as
|
||||
|
@ -86,6 +91,11 @@ type MockDB struct {
|
|||
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
|
||||
MockDeleteAdmin func(ctx context.Context, id string) error
|
||||
|
||||
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
|
||||
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
|
||||
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
|
||||
MockDeleteAuthorityPolicy func(ctx context.Context) error
|
||||
|
||||
MockError error
|
||||
MockRet1 interface{}
|
||||
}
|
||||
|
@ -179,3 +189,30 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
|
|||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
if m.MockCreateAuthorityPolicy != nil {
|
||||
return m.MockCreateAuthorityPolicy(ctx, policy)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
if m.MockGetAuthorityPolicy != nil {
|
||||
return m.MockGetAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockRet1.(*linkedca.Policy), m.MockError
|
||||
}
|
||||
|
||||
func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
if m.MockUpdateAuthorityPolicy != nil {
|
||||
return m.MockUpdateAuthorityPolicy(ctx, policy)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
||||
func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error {
|
||||
if m.MockDeleteAuthorityPolicy != nil {
|
||||
return m.MockDeleteAuthorityPolicy(ctx)
|
||||
}
|
||||
return m.MockError
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
var (
|
||||
adminsTable = []byte("admins")
|
||||
provisionersTable = []byte("provisioners")
|
||||
authorityPoliciesTable = []byte("authority_policies")
|
||||
)
|
||||
|
||||
// DB is a struct that implements the AdminDB interface.
|
||||
|
@ -23,7 +24,7 @@ type DB struct {
|
|||
|
||||
// New configures and returns a new Authority DB backend implemented using a nosql DB.
|
||||
func New(db nosqlDB.DB, authorityID string) (*DB, error) {
|
||||
tables := [][]byte{adminsTable, provisionersTable}
|
||||
tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable}
|
||||
for _, b := range tables {
|
||||
if err := db.CreateTable(b); err != nil {
|
||||
return nil, errors.Wrapf(err, "error creating table %s",
|
||||
|
|
144
authority/admin/db/nosql/policy.go
Normal file
144
authority/admin/db/nosql/policy.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package nosql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/nosql"
|
||||
"go.step.sm/linkedca"
|
||||
)
|
||||
|
||||
type dbAuthorityPolicy struct {
|
||||
ID string `json:"id"`
|
||||
AuthorityID string `json:"authorityID"`
|
||||
Policy *linkedca.Policy `json:"policy"`
|
||||
}
|
||||
|
||||
func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy {
|
||||
return dbap.Policy
|
||||
}
|
||||
|
||||
func (dbap *dbAuthorityPolicy) clone() *dbAuthorityPolicy {
|
||||
u := *dbap
|
||||
return &u
|
||||
}
|
||||
|
||||
func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) {
|
||||
data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID))
|
||||
if nosql.IsErrNotFound(err) {
|
||||
return nil, admin.NewError(admin.ErrorNotFoundType, "policy %s not found", authorityID)
|
||||
} else if err != nil {
|
||||
return nil, errors.Wrapf(err, "error loading admin %s", authorityID)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (db *DB) unmarshalDBAuthorityPolicy(data []byte, authorityID string) (*dbAuthorityPolicy, error) {
|
||||
var dba = new(dbAuthorityPolicy)
|
||||
if err := json.Unmarshal(data, dba); err != nil {
|
||||
return nil, errors.Wrapf(err, "error unmarshaling admin %s into dbAdmin", authorityID)
|
||||
}
|
||||
// if !dba.DeletedAt.IsZero() {
|
||||
// return nil, admin.NewError(admin.ErrorDeletedType, "admin %s is deleted", authorityID)
|
||||
// }
|
||||
if dba.AuthorityID != db.authorityID {
|
||||
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
|
||||
"admin %s is not owned by authority %s", dba.ID, db.authorityID)
|
||||
}
|
||||
return dba, nil
|
||||
}
|
||||
|
||||
func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) {
|
||||
data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dbap, nil
|
||||
}
|
||||
|
||||
// func (db *DB) unmarshalAuthorityPolicy(data []byte, authorityID string) (*linkedca.Policy, error) {
|
||||
// dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return dbap.convert(), nil
|
||||
// }
|
||||
|
||||
func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
|
||||
dbap := &dbAuthorityPolicy{
|
||||
ID: db.authorityID,
|
||||
AuthorityID: db.authorityID,
|
||||
Policy: policy,
|
||||
}
|
||||
|
||||
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable)
|
||||
}
|
||||
|
||||
func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
// policy := &linkedca.Policy{
|
||||
// X509: &linkedca.X509Policy{
|
||||
// Allow: &linkedca.X509Names{
|
||||
// Dns: []string{".localhost"},
|
||||
// },
|
||||
// Deny: &linkedca.X509Names{
|
||||
// Dns: []string{"denied.localhost"},
|
||||
// },
|
||||
// },
|
||||
// Ssh: &linkedca.SSHPolicy{
|
||||
// User: &linkedca.SSHUserPolicy{
|
||||
// Allow: &linkedca.SSHUserNames{},
|
||||
// Deny: &linkedca.SSHUserNames{},
|
||||
// },
|
||||
// Host: &linkedca.SSHHostPolicy{
|
||||
// Allow: &linkedca.SSHHostNames{},
|
||||
// Deny: &linkedca.SSHHostNames{},
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
|
||||
dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dbap.convert(), nil
|
||||
}
|
||||
|
||||
func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbap := &dbAuthorityPolicy{
|
||||
ID: db.authorityID,
|
||||
AuthorityID: db.authorityID,
|
||||
Policy: policy,
|
||||
}
|
||||
|
||||
return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable)
|
||||
}
|
||||
|
||||
func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error {
|
||||
dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
old := dbap.clone()
|
||||
|
||||
dbap.Policy = nil
|
||||
return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable)
|
||||
}
|
|
@ -19,6 +19,7 @@ type dbProvisioner struct {
|
|||
Type linkedca.Provisioner_Type `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Claims *linkedca.Claims `json:"claims"`
|
||||
Policy *linkedca.Policy `json:"policy"`
|
||||
Details []byte `json:"details"`
|
||||
X509Template *linkedca.Template `json:"x509Template"`
|
||||
SSHTemplate *linkedca.Template `json:"sshTemplate"`
|
||||
|
@ -43,6 +44,7 @@ func (dbp *dbProvisioner) convert2linkedca() (*linkedca.Provisioner, error) {
|
|||
Type: dbp.Type,
|
||||
Name: dbp.Name,
|
||||
Claims: dbp.Claims,
|
||||
Policy: dbp.Policy,
|
||||
Details: details,
|
||||
X509Template: dbp.X509Template,
|
||||
SshTemplate: dbp.SSHTemplate,
|
||||
|
@ -160,6 +162,7 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner)
|
|||
Type: prov.Type,
|
||||
Name: prov.Name,
|
||||
Claims: prov.Claims,
|
||||
Policy: prov.Policy,
|
||||
Details: details,
|
||||
X509Template: prov.X509Template,
|
||||
SSHTemplate: prov.SshTemplate,
|
||||
|
@ -187,6 +190,7 @@ func (db *DB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner)
|
|||
}
|
||||
nu.Name = prov.Name
|
||||
nu.Claims = prov.Claims
|
||||
nu.Policy = prov.Policy
|
||||
nu.Details, err = json.Marshal(prov.Details.GetData())
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name)
|
||||
|
|
|
@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool
|
|||
}
|
||||
|
||||
// Store adds an admin to the collection and enforces the uniqueness of
|
||||
// admin IDs and amdin subject <-> provisioner name combos.
|
||||
// admin IDs and admin subject <-> provisioner name combos.
|
||||
func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error {
|
||||
// Input validation.
|
||||
if adm.ProvisionerId != prov.GetID() {
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
|
||||
"github.com/smallstep/certificates/authority/administrator"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/cas"
|
||||
casapi "github.com/smallstep/certificates/cas/apiv1"
|
||||
|
@ -74,6 +75,13 @@ type Authority struct {
|
|||
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
|
||||
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
|
||||
getIdentityFunc provisioner.GetIdentityFunc
|
||||
authorizeRenewFunc provisioner.AuthorizeRenewFunc
|
||||
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
|
||||
|
||||
// Policy engines
|
||||
x509Policy policy.X509Policy
|
||||
sshUserPolicy policy.UserPolicy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
|
||||
adminMutex sync.RWMutex
|
||||
}
|
||||
|
@ -199,6 +207,51 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
|
|||
a.provisioners = provClxn
|
||||
a.config.AuthorityConfig.Admins = adminList
|
||||
a.admins = adminClxn
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadPolicyEngines reloads x509 and SSH policy engines using
|
||||
// configuration stored in the DB or from the configuration file.
|
||||
func (a *Authority) reloadPolicyEngines(ctx context.Context) error {
|
||||
var (
|
||||
err error
|
||||
policyOptions *policy.Options
|
||||
)
|
||||
if a.config.AuthorityConfig.EnableAdmin {
|
||||
linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx)
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error getting policy to (re)load policy engines")
|
||||
}
|
||||
policyOptions = policyToCertificates(linkedPolicy)
|
||||
} else {
|
||||
policyOptions = a.config.AuthorityConfig.Policy
|
||||
}
|
||||
|
||||
// if no new or updated policy option is set, clear policy engines that (may have)
|
||||
// been configured before and return early
|
||||
if policyOptions == nil {
|
||||
a.x509Policy = nil
|
||||
a.sshHostPolicy = nil
|
||||
a.sshUserPolicy = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if a.x509Policy, err = policy.NewX509PolicyEngine(policyOptions.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// // Initialize the SSH allow/deny policy engine for host certificates
|
||||
if a.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(policyOptions.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// // Initialize the SSH allow/deny policy engine for user certificates
|
||||
if a.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(policyOptions.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -527,6 +580,11 @@ func (a *Authority) init() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Load x509 and SSH Policy Engines
|
||||
if err := a.reloadPolicyEngines(context.Background()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure templates, currently only ssh templates are supported.
|
||||
if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil {
|
||||
a.templates = a.config.Templates
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -276,6 +277,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
|
|||
func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
||||
serial := cert.SerialNumber.String()
|
||||
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
|
||||
|
||||
isRevoked, err := a.IsRevoked(serial)
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
|
||||
|
@ -283,7 +285,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
|
|||
if isRevoked {
|
||||
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
|
||||
}
|
||||
|
||||
p, ok := a.provisioners.LoadByCertificate(cert)
|
||||
if !ok {
|
||||
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
|
||||
|
@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthorizeRenewToken validates the renew token and returns the leaf
|
||||
// certificate in the x5cInsecure header.
|
||||
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
|
||||
var claims jose.Claims
|
||||
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
|
||||
if err != nil {
|
||||
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
|
||||
}
|
||||
leaf := chain[0][0]
|
||||
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
|
||||
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
|
||||
}
|
||||
|
||||
p, ok := a.provisioners.LoadByCertificate(leaf)
|
||||
if !ok {
|
||||
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
|
||||
}
|
||||
if err := a.UseToken(ott, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||
Issuer: p.GetName(),
|
||||
Subject: leaf.Subject.CommonName,
|
||||
Time: time.Now().UTC(),
|
||||
}, time.Minute); err != nil {
|
||||
switch err {
|
||||
case jose.ErrInvalidIssuer:
|
||||
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)"))
|
||||
case jose.ErrInvalidSubject:
|
||||
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)"))
|
||||
case jose.ErrNotValidYet:
|
||||
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)"))
|
||||
case jose.ErrExpired:
|
||||
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)"))
|
||||
case jose.ErrIssuedInTheFuture:
|
||||
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)"))
|
||||
default:
|
||||
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
|
||||
}
|
||||
}
|
||||
|
||||
audiences := a.config.GetAudiences().Renew
|
||||
if !matchesAudience(claims.Audience, audiences) {
|
||||
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
|
||||
}
|
||||
|
||||
return leaf, nil
|
||||
}
|
||||
|
||||
// matchesAudience returns true if A and B share at least one element.
|
||||
func matchesAudience(as, bs []string) bool {
|
||||
if len(bs) == 0 || len(as) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, b := range bs {
|
||||
for _, a := range as {
|
||||
if b == a || stripPort(a) == stripPort(b) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// stripPort attempts to strip the port from the given url. If parsing the url
|
||||
// produces errors it will just return the passed argument.
|
||||
func stripPort(rawurl string) string {
|
||||
u, err := url.Parse(rawurl)
|
||||
if err != nil {
|
||||
return rawurl
|
||||
}
|
||||
u.Host = u.Hostname()
|
||||
return u.String()
|
||||
}
|
||||
|
|
|
@ -3,11 +3,15 @@ package authority
|
|||
import (
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -20,6 +24,7 @@ import (
|
|||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"go.step.sm/crypto/randutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
|
@ -753,6 +758,7 @@ func TestAuthority_Authorize(t *testing.T) {
|
|||
|
||||
func TestAuthority_authorizeRenew(t *testing.T) {
|
||||
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
|
||||
fooCrt.NotAfter = time.Now().Add(time.Hour)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
|
||||
|
@ -822,7 +828,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
|
|||
return &authorizeTest{
|
||||
auth: a,
|
||||
cert: renewDisabledCrt,
|
||||
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"),
|
||||
err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"),
|
||||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
|
@ -909,6 +915,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.
|
|||
}
|
||||
|
||||
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
|
||||
now := time.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -917,6 +924,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if cert.ValidAfter == 0 {
|
||||
cert.ValidAfter = uint64(now.Unix())
|
||||
}
|
||||
if cert.ValidBefore == 0 {
|
||||
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
|
||||
}
|
||||
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -1003,6 +1016,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
sshpop := func(a *Authority) (*ssh.Certificate, string) {
|
||||
p, ok := a.provisioners.Load("sshpop/sshpop")
|
||||
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
|
||||
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
signer, ok := key.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
|
||||
sshSigner, err := ssh.NewSignerFromSigner(signer)
|
||||
assert.FatalError(t, err)
|
||||
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
return cert, token
|
||||
}
|
||||
|
||||
a := testAuthority(t)
|
||||
|
||||
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
|
||||
|
@ -1012,8 +1042,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
|||
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
|
||||
assert.FatalError(t, err)
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
validIssuer := "step-cli"
|
||||
|
||||
type authorizeTest struct {
|
||||
|
@ -1050,27 +1078,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
|
|||
code: http.StatusUnauthorized,
|
||||
}
|
||||
},
|
||||
"fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
|
||||
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
|
||||
return errs.Forbidden("forbidden")
|
||||
}))
|
||||
_, token := sshpop(aa)
|
||||
return &authorizeTest{
|
||||
auth: aa,
|
||||
token: token,
|
||||
err: errors.New("authority.authorizeSSHRenew: forbidden"),
|
||||
code: http.StatusForbidden,
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) *authorizeTest {
|
||||
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
|
||||
assert.FatalError(t, err)
|
||||
signer, ok := key.(crypto.Signer)
|
||||
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
|
||||
sshSigner, err := ssh.NewSignerFromSigner(signer)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
p, ok := a.provisioners.Load("sshpop/sshpop")
|
||||
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
|
||||
|
||||
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop",
|
||||
[]string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert))
|
||||
assert.FatalError(t, err)
|
||||
|
||||
cert, token := sshpop(a)
|
||||
return &authorizeTest{
|
||||
auth: a,
|
||||
token: tok,
|
||||
token: token,
|
||||
cert: cert,
|
||||
}
|
||||
},
|
||||
"ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
|
||||
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
|
||||
return nil
|
||||
}))
|
||||
cert, token := sshpop(aa)
|
||||
return &authorizeTest{
|
||||
auth: aa,
|
||||
token: token,
|
||||
cert: cert,
|
||||
}
|
||||
},
|
||||
|
@ -1290,3 +1325,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthority_AuthorizeRenewToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
type stepProvisionerASN1 struct {
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||
}
|
||||
|
||||
_, signer, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, otherSigner, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
|
||||
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var x5c []string
|
||||
for _, c := range chain {
|
||||
x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw))
|
||||
}
|
||||
|
||||
so := new(jose.SignerOptions)
|
||||
so.WithType("JWT")
|
||||
so.WithHeader("x5cInsecure", x5c)
|
||||
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return s, chain[0]
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
a1 := testAuthority(t)
|
||||
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
IssuedAt: jose.NewNumericDate(now),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now.Add(-time.Hour)
|
||||
cert.NotAfter = now.Add(-time.Minute)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "bad-issuer",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/renew"},
|
||||
Subject: "bad-subject",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/sign"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/sign"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
|
||||
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/sign"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
|
||||
Audience: []string{"https://example.com/1.0/sign"},
|
||||
Subject: "test.example.com",
|
||||
Issuer: "step-cli",
|
||||
NotBefore: jose.NewNumericDate(now),
|
||||
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
|
||||
cert.NotBefore = now
|
||||
cert.NotAfter = now.Add(time.Hour)
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
|
||||
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
|
||||
Value: b,
|
||||
})
|
||||
return nil
|
||||
}))
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
ott string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
authority *Authority
|
||||
args args
|
||||
want *x509.Certificate
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", a1, args{ctx, t1}, c1, false},
|
||||
{"ok expired cert", a1, args{ctx, t2}, c2, false},
|
||||
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
|
||||
{"fail token reuse", a1, args{ctx, t1}, nil, true},
|
||||
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
|
||||
{"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true},
|
||||
{"fail token iss", a1, args{ctx, badIssuer}, nil, true},
|
||||
{"fail token sub", a1, args{ctx, badSubject}, nil, true},
|
||||
{"fail token iat", a1, args{ctx, badNotBefore}, nil, true},
|
||||
{"fail token iat", a1, args{ctx, badExpiry}, nil, true},
|
||||
{"fail token iat", a1, args{ctx, badIssuedAt}, nil, true},
|
||||
{"fail token aud", a1, args{ctx, badAudience}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
cas "github.com/smallstep/certificates/cas/apiv1"
|
||||
"github.com/smallstep/certificates/db"
|
||||
|
@ -26,6 +27,9 @@ var (
|
|||
DefaultBackdate = time.Minute
|
||||
// DefaultDisableRenewal disables renewals per provisioner.
|
||||
DefaultDisableRenewal = false
|
||||
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
|
||||
// expired.
|
||||
DefaultAllowRenewAfterExpiry = false
|
||||
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally
|
||||
// for all provisioners.
|
||||
DefaultEnableSSHCA = false
|
||||
|
@ -35,7 +39,6 @@ var (
|
|||
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
|
||||
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DisableRenewal: &DefaultDisableRenewal,
|
||||
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
|
||||
|
@ -43,6 +46,8 @@ var (
|
|||
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &DefaultEnableSSHCA,
|
||||
DisableRenewal: &DefaultDisableRenewal,
|
||||
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -90,6 +95,7 @@ type AuthConfig struct {
|
|||
Admins []*linkedca.Admin `json:"-"`
|
||||
Template *ASN1DN `json:"template,omitempty"`
|
||||
Claims *provisioner.Claims `json:"claims,omitempty"`
|
||||
Policy *policy.Options `json:"policy,omitempty"`
|
||||
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
|
||||
Backdate *provisioner.Duration `json:"backdate,omitempty"`
|
||||
EnableAdmin bool `json:"enableAdmin,omitempty"`
|
||||
|
@ -269,28 +275,32 @@ func (c *Config) GetAudiences() provisioner.Audiences {
|
|||
}
|
||||
|
||||
for _, name := range c.DNSNames {
|
||||
hostname := toHostname(name)
|
||||
audiences.Sign = append(audiences.Sign,
|
||||
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/sign", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)))
|
||||
fmt.Sprintf("https://%s/1.0/sign", hostname),
|
||||
fmt.Sprintf("https://%s/sign", hostname),
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
|
||||
fmt.Sprintf("https://%s/ssh/sign", hostname))
|
||||
audiences.Renew = append(audiences.Renew,
|
||||
fmt.Sprintf("https://%s/1.0/renew", hostname),
|
||||
fmt.Sprintf("https://%s/renew", hostname))
|
||||
audiences.Revoke = append(audiences.Revoke,
|
||||
fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/revoke", toHostname(name)))
|
||||
fmt.Sprintf("https://%s/1.0/revoke", hostname),
|
||||
fmt.Sprintf("https://%s/revoke", hostname))
|
||||
audiences.SSHSign = append(audiences.SSHSign,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/sign", toHostname(name)))
|
||||
fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
|
||||
fmt.Sprintf("https://%s/ssh/sign", hostname),
|
||||
fmt.Sprintf("https://%s/1.0/sign", hostname),
|
||||
fmt.Sprintf("https://%s/sign", hostname))
|
||||
audiences.SSHRevoke = append(audiences.SSHRevoke,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/ssh/revoke", toHostname(name)))
|
||||
fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname),
|
||||
fmt.Sprintf("https://%s/ssh/revoke", hostname))
|
||||
audiences.SSHRenew = append(audiences.SSHRenew,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/ssh/renew", toHostname(name)))
|
||||
fmt.Sprintf("https://%s/1.0/ssh/renew", hostname),
|
||||
fmt.Sprintf("https://%s/ssh/renew", hostname))
|
||||
audiences.SSHRekey = append(audiences.SSHRekey,
|
||||
fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)),
|
||||
fmt.Sprintf("https://%s/ssh/rekey", toHostname(name)))
|
||||
fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname),
|
||||
fmt.Sprintf("https://%s/ssh/rekey", hostname))
|
||||
}
|
||||
|
||||
return audiences
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/db"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
|
@ -34,6 +35,9 @@ type linkedCaClient struct {
|
|||
authorityID string
|
||||
}
|
||||
|
||||
// interface guard
|
||||
var _ admin.DB = (*linkedCaClient)(nil)
|
||||
|
||||
type linkedCAClaims struct {
|
||||
jose.Claims
|
||||
SANs []string `json:"sans"`
|
||||
|
@ -310,6 +314,22 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
|
|||
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
return errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
return nil, errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
|
||||
return errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error {
|
||||
return errors.New("not implemented yet")
|
||||
}
|
||||
|
||||
func serializeCertificate(crt *x509.Certificate) string {
|
||||
if crt == nil {
|
||||
return ""
|
||||
|
|
|
@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e
|
|||
}
|
||||
}
|
||||
|
||||
// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of
|
||||
// an X.509 certificate.
|
||||
func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option {
|
||||
return func(a *Authority) error {
|
||||
a.authorizeRenewFunc = fn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal
|
||||
// of a SSH certificate.
|
||||
func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option {
|
||||
return func(a *Authority) error {
|
||||
a.authorizeSSHRenewFunc = fn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithSSHBastionFunc sets a custom function to get the bastion for a
|
||||
// given user-host pair.
|
||||
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option {
|
||||
|
|
195
authority/policy.go
Normal file
195
authority/policy.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package authority
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/linkedca"
|
||||
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
authPolicy "github.com/smallstep/certificates/authority/policy"
|
||||
policy "github.com/smallstep/certificates/policy"
|
||||
)
|
||||
|
||||
func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
p, err := a.adminDB.GetAuthorityPolicy(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
if err := a.checkPolicy(ctx, adm, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||
return nil, admin.WrapErrorISE(err, "error reloading policy engines when creating authority policy")
|
||||
}
|
||||
|
||||
return p, nil // TODO: return the newly stored policy
|
||||
}
|
||||
|
||||
func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
if err := a.checkPolicy(ctx, adm, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||
return nil, admin.WrapErrorISE(err, "error reloading policy engines when updating authority policy")
|
||||
}
|
||||
|
||||
return p, nil // TODO: return the updated stored policy
|
||||
}
|
||||
|
||||
func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error {
|
||||
a.adminMutex.Lock()
|
||||
defer a.adminMutex.Unlock()
|
||||
|
||||
if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := a.reloadPolicyEngines(ctx); err != nil {
|
||||
return admin.WrapErrorISE(err, "error reloading policy engines when deleting authority policy")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Authority) checkPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) error {
|
||||
|
||||
// convert the policy; return early if nil
|
||||
policyOptions := policyToCertificates(p)
|
||||
if policyOptions == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options())
|
||||
if err != nil {
|
||||
return admin.WrapErrorISE(err, "error creating temporary policy engine")
|
||||
}
|
||||
|
||||
// TODO(hs): Provide option to force the policy, even when the admin subject would be locked out?
|
||||
|
||||
sans := []string{adm.Subject}
|
||||
if err := isAllowed(engine, sans); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(hs): perform the check for other admin subjects too?
|
||||
// What logic to use for that: do all admins need access? Only super admins? At least one?
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAllowed(engine authPolicy.X509Policy, sans []string) error {
|
||||
var (
|
||||
allowed bool
|
||||
err error
|
||||
)
|
||||
if allowed, err = engine.AreSANsAllowed(sans); err != nil {
|
||||
var policyErr *policy.NamePolicyError
|
||||
if isPolicyErr := errors.As(err, &policyErr); isPolicyErr && policyErr.Reason == policy.NotAuthorizedForThisName {
|
||||
return fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func policyToCertificates(p *linkedca.Policy) *authPolicy.Options {
|
||||
|
||||
// return early
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepare full policy struct
|
||||
opts := &authPolicy.Options{
|
||||
X509: &authPolicy.X509PolicyOptions{
|
||||
AllowedNames: &authPolicy.X509NameOptions{},
|
||||
DeniedNames: &authPolicy.X509NameOptions{},
|
||||
},
|
||||
SSH: &authPolicy.SSHPolicyOptions{
|
||||
Host: &authPolicy.SSHHostCertificateOptions{
|
||||
AllowedNames: &authPolicy.SSHNameOptions{},
|
||||
DeniedNames: &authPolicy.SSHNameOptions{},
|
||||
},
|
||||
User: &authPolicy.SSHUserCertificateOptions{
|
||||
AllowedNames: &authPolicy.SSHNameOptions{},
|
||||
DeniedNames: &authPolicy.SSHNameOptions{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// fill x509 policy configuration
|
||||
if p.X509 != nil {
|
||||
if p.X509.Allow != nil {
|
||||
opts.X509.AllowedNames.DNSDomains = p.X509.Allow.Dns
|
||||
opts.X509.AllowedNames.IPRanges = p.X509.Allow.Ips
|
||||
opts.X509.AllowedNames.EmailAddresses = p.X509.Allow.Emails
|
||||
opts.X509.AllowedNames.URIDomains = p.X509.Allow.Uris
|
||||
}
|
||||
if p.X509.Deny != nil {
|
||||
opts.X509.DeniedNames.DNSDomains = p.X509.Deny.Dns
|
||||
opts.X509.DeniedNames.IPRanges = p.X509.Deny.Ips
|
||||
opts.X509.DeniedNames.EmailAddresses = p.X509.Deny.Emails
|
||||
opts.X509.DeniedNames.URIDomains = p.X509.Deny.Uris
|
||||
}
|
||||
}
|
||||
|
||||
// fill ssh policy configuration
|
||||
if p.Ssh != nil {
|
||||
if p.Ssh.Host != nil {
|
||||
if p.Ssh.Host.Allow != nil {
|
||||
opts.SSH.Host.AllowedNames.DNSDomains = p.Ssh.Host.Allow.Dns
|
||||
opts.SSH.Host.AllowedNames.IPRanges = p.Ssh.Host.Allow.Ips
|
||||
opts.SSH.Host.AllowedNames.EmailAddresses = p.Ssh.Host.Allow.Principals
|
||||
}
|
||||
if p.Ssh.Host.Deny != nil {
|
||||
opts.SSH.Host.DeniedNames.DNSDomains = p.Ssh.Host.Deny.Dns
|
||||
opts.SSH.Host.DeniedNames.IPRanges = p.Ssh.Host.Deny.Ips
|
||||
opts.SSH.Host.DeniedNames.Principals = p.Ssh.Host.Deny.Principals
|
||||
}
|
||||
}
|
||||
if p.Ssh.User != nil {
|
||||
if p.Ssh.User.Allow != nil {
|
||||
opts.SSH.User.AllowedNames.EmailAddresses = p.Ssh.User.Allow.Emails
|
||||
opts.SSH.User.AllowedNames.Principals = p.Ssh.User.Allow.Principals
|
||||
}
|
||||
if p.Ssh.User.Deny != nil {
|
||||
opts.SSH.User.DeniedNames.EmailAddresses = p.Ssh.User.Deny.Emails
|
||||
opts.SSH.User.DeniedNames.Principals = p.Ssh.User.Deny.Principals
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
188
authority/policy/options.go
Normal file
188
authority/policy/options.go
Normal file
|
@ -0,0 +1,188 @@
|
|||
package policy
|
||||
|
||||
// Options is a container for authority level x509 and SSH
|
||||
// policy configuration.
|
||||
type Options struct {
|
||||
X509 *X509PolicyOptions `json:"x509,omitempty"`
|
||||
SSH *SSHPolicyOptions `json:"ssh,omitempty"`
|
||||
}
|
||||
|
||||
// GetX509Options returns the x509 authority level policy
|
||||
// configuration
|
||||
func (o *Options) GetX509Options() *X509PolicyOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.X509
|
||||
}
|
||||
|
||||
// GetSSHOptions returns the SSH authority level policy
|
||||
// configuration
|
||||
func (o *Options) GetSSHOptions() *SSHPolicyOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.SSH
|
||||
}
|
||||
|
||||
// X509PolicyOptionsInterface is an interface for providers
|
||||
// of x509 allowed and denied names.
|
||||
type X509PolicyOptionsInterface interface {
|
||||
GetAllowedNameOptions() *X509NameOptions
|
||||
GetDeniedNameOptions() *X509NameOptions
|
||||
}
|
||||
|
||||
// X509PolicyOptions is a container for x509 allowed and denied
|
||||
// names.
|
||||
type X509PolicyOptions struct {
|
||||
// AllowedNames contains the x509 allowed names
|
||||
AllowedNames *X509NameOptions `json:"allow,omitempty"`
|
||||
// DeniedNames contains the x509 denied names
|
||||
DeniedNames *X509NameOptions `json:"deny,omitempty"`
|
||||
}
|
||||
|
||||
// X509NameOptions models the X509 name policy configuration.
|
||||
type X509NameOptions struct {
|
||||
DNSDomains []string `json:"dns,omitempty"`
|
||||
IPRanges []string `json:"ip,omitempty"`
|
||||
EmailAddresses []string `json:"email,omitempty"`
|
||||
URIDomains []string `json:"uri,omitempty"`
|
||||
}
|
||||
|
||||
// HasNames checks if the AllowedNameOptions has one or more
|
||||
// names configured.
|
||||
func (o *X509NameOptions) HasNames() bool {
|
||||
return len(o.DNSDomains) > 0 ||
|
||||
len(o.IPRanges) > 0 ||
|
||||
len(o.EmailAddresses) > 0 ||
|
||||
len(o.URIDomains) > 0
|
||||
}
|
||||
|
||||
// SSHPolicyOptionsInterface is an interface for providers of
|
||||
// SSH user and host name policy configuration.
|
||||
type SSHPolicyOptionsInterface interface {
|
||||
GetAllowedUserNameOptions() *SSHNameOptions
|
||||
GetDeniedUserNameOptions() *SSHNameOptions
|
||||
GetAllowedHostNameOptions() *SSHNameOptions
|
||||
GetDeniedHostNameOptions() *SSHNameOptions
|
||||
}
|
||||
|
||||
// SSHPolicyOptions is a container for SSH user and host policy
|
||||
// configuration
|
||||
type SSHPolicyOptions struct {
|
||||
// User contains SSH user certificate options.
|
||||
User *SSHUserCertificateOptions `json:"user,omitempty"`
|
||||
// Host contains SSH host certificate options.
|
||||
Host *SSHHostCertificateOptions `json:"host,omitempty"`
|
||||
}
|
||||
|
||||
// GetAllowedNameOptions returns x509 allowed name policy configuration
|
||||
func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedNameOptions returns the x509 denied name policy configuration
|
||||
func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.DeniedNames
|
||||
}
|
||||
|
||||
// GetAllowedUserNameOptions returns the SSH allowed user name policy
|
||||
// configuration.
|
||||
func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
if o.User == nil {
|
||||
return nil
|
||||
}
|
||||
return o.User.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedUserNameOptions returns the SSH denied user name policy
|
||||
// configuration.
|
||||
func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
if o.User == nil {
|
||||
return nil
|
||||
}
|
||||
return o.User.DeniedNames
|
||||
}
|
||||
|
||||
// GetAllowedHostNameOptions returns the SSH allowed host name policy
|
||||
// configuration.
|
||||
func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
if o.Host == nil {
|
||||
return nil
|
||||
}
|
||||
return o.Host.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedHostNameOptions returns the SSH denied host name policy
|
||||
// configuration.
|
||||
func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
if o.Host == nil {
|
||||
return nil
|
||||
}
|
||||
return o.Host.DeniedNames
|
||||
}
|
||||
|
||||
// SSHUserCertificateOptions is a collection of SSH user certificate options.
|
||||
type SSHUserCertificateOptions struct {
|
||||
// AllowedNames contains the names the provisioner is authorized to sign
|
||||
AllowedNames *SSHNameOptions `json:"allow,omitempty"`
|
||||
// DeniedNames contains the names the provisioner is not authorized to sign
|
||||
DeniedNames *SSHNameOptions `json:"deny,omitempty"`
|
||||
}
|
||||
|
||||
// SSHHostCertificateOptions is a collection of SSH host certificate options.
|
||||
// It's an alias of SSHUserCertificateOptions, as the options are the same
|
||||
// for both types of certificates.
|
||||
type SSHHostCertificateOptions SSHUserCertificateOptions
|
||||
|
||||
// SSHNameOptions models the SSH name policy configuration.
|
||||
type SSHNameOptions struct {
|
||||
DNSDomains []string `json:"dns,omitempty"`
|
||||
IPRanges []string `json:"ip,omitempty"`
|
||||
EmailAddresses []string `json:"email,omitempty"`
|
||||
Principals []string `json:"principal,omitempty"`
|
||||
}
|
||||
|
||||
// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the
|
||||
// names that a provisioner is authorized to sign SSH certificates for.
|
||||
func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the
|
||||
// names that a provisioner is NOT authorized to sign SSH certificates for.
|
||||
func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.DeniedNames
|
||||
}
|
||||
|
||||
// HasNames checks if the SSHNameOptions has one or more
|
||||
// names configured.
|
||||
func (o *SSHNameOptions) HasNames() bool {
|
||||
return len(o.DNSDomains) > 0 ||
|
||||
len(o.EmailAddresses) > 0 ||
|
||||
len(o.Principals) > 0
|
||||
}
|
134
authority/policy/policy.go
Normal file
134
authority/policy/policy.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/smallstep/certificates/policy"
|
||||
)
|
||||
|
||||
// X509Policy is an alias for policy.X509NamePolicyEngine
|
||||
type X509Policy policy.X509NamePolicyEngine
|
||||
|
||||
// UserPolicy is an alias for policy.SSHNamePolicyEngine
|
||||
type UserPolicy policy.SSHNamePolicyEngine
|
||||
|
||||
// HostPolicy is an alias for policy.SSHNamePolicyEngine
|
||||
type HostPolicy policy.SSHNamePolicyEngine
|
||||
|
||||
// NewX509PolicyEngine creates a new x509 name policy engine
|
||||
func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) {
|
||||
|
||||
// return early if no policy engine options to configure
|
||||
if policyOptions == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
options := []policy.NamePolicyOption{}
|
||||
|
||||
allowed := policyOptions.GetAllowedNameOptions()
|
||||
if allowed != nil && allowed.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithPermittedDNSDomains(allowed.DNSDomains),
|
||||
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges),
|
||||
policy.WithPermittedEmailAddresses(allowed.EmailAddresses),
|
||||
policy.WithPermittedURIDomains(allowed.URIDomains),
|
||||
)
|
||||
}
|
||||
|
||||
denied := policyOptions.GetDeniedNameOptions()
|
||||
if denied != nil && denied.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithExcludedDNSDomains(denied.DNSDomains),
|
||||
policy.WithExcludedIPsOrCIDRs(denied.IPRanges),
|
||||
policy.WithExcludedEmailAddresses(denied.EmailAddresses),
|
||||
policy.WithExcludedURIDomains(denied.URIDomains),
|
||||
)
|
||||
}
|
||||
|
||||
// ensure no policy engine is returned when no name options were provided
|
||||
if len(options) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// enable x509 Subject Common Name validation by default
|
||||
options = append(options, policy.WithSubjectCommonNameVerification())
|
||||
|
||||
return policy.New(options...)
|
||||
}
|
||||
|
||||
type sshPolicyEngineType string
|
||||
|
||||
const (
|
||||
UserPolicyEngineType sshPolicyEngineType = "user"
|
||||
HostPolicyEngineType sshPolicyEngineType = "host"
|
||||
)
|
||||
|
||||
// newSSHUserPolicyEngine creates a new SSH user certificate policy engine
|
||||
func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) {
|
||||
policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return policyEngine, nil
|
||||
}
|
||||
|
||||
// newSSHHostPolicyEngine create a new SSH host certificate policy engine
|
||||
func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) {
|
||||
policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return policyEngine, nil
|
||||
}
|
||||
|
||||
// newSSHPolicyEngine creates a new SSH name policy engine
|
||||
func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
|
||||
|
||||
// return early if no policy engine options to configure
|
||||
if policyOptions == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
allowed *SSHNameOptions
|
||||
denied *SSHNameOptions
|
||||
)
|
||||
|
||||
switch typ {
|
||||
case UserPolicyEngineType:
|
||||
allowed = policyOptions.GetAllowedUserNameOptions()
|
||||
denied = policyOptions.GetDeniedUserNameOptions()
|
||||
case HostPolicyEngineType:
|
||||
allowed = policyOptions.GetAllowedHostNameOptions()
|
||||
denied = policyOptions.GetDeniedHostNameOptions()
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
|
||||
}
|
||||
|
||||
options := []policy.NamePolicyOption{}
|
||||
|
||||
if allowed != nil && allowed.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithPermittedDNSDomains(allowed.DNSDomains),
|
||||
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges),
|
||||
policy.WithPermittedEmailAddresses(allowed.EmailAddresses),
|
||||
policy.WithPermittedPrincipals(allowed.Principals),
|
||||
)
|
||||
}
|
||||
|
||||
if denied != nil && denied.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithExcludedDNSDomains(denied.DNSDomains),
|
||||
policy.WithExcludedIPsOrCIDRs(denied.IPRanges),
|
||||
policy.WithExcludedEmailAddresses(denied.EmailAddresses),
|
||||
policy.WithExcludedPrincipals(denied.Principals),
|
||||
)
|
||||
}
|
||||
|
||||
// ensure no policy engine is returned when no name options were provided
|
||||
if len(options) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return policy.New(options...)
|
||||
}
|
|
@ -7,8 +7,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"github.com/smallstep/certificates/policy"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
// ACME is the acme provisioner type, an entity that can authorize the ACME
|
||||
|
@ -26,8 +26,9 @@ type ACME struct {
|
|||
RequireEAB bool `json:"requireEAB,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
claimer *Claimer
|
||||
x509Policy policy.X509NamePolicyEngine
|
||||
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
|
@ -72,7 +73,7 @@ func (p *ACME) GetOptions() *Options {
|
|||
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
|
||||
// the provisioner.
|
||||
func (p *ACME) DefaultTLSCertDuration() time.Duration {
|
||||
return p.claimer.DefaultTLSCertDuration()
|
||||
return p.ctl.Claimer.DefaultTLSCertDuration()
|
||||
}
|
||||
|
||||
// Init initializes and validates the fields of an ACME type.
|
||||
|
@ -84,19 +85,13 @@ func (p *ACME) Init(config Config) (err error) {
|
|||
return errors.New("provisioner name cannot be empty")
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
// TODO(hs): ensure no race conditions happen when reloading settings and requesting certs?
|
||||
// TODO(hs): implement memoization strategy, so that reloading is not required when no changes were made to allow/deny?
|
||||
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// ACMEIdentifierType encodes ACME Identifier types
|
||||
|
@ -115,20 +110,22 @@ type ACMEIdentifier struct {
|
|||
Value string
|
||||
}
|
||||
|
||||
// AuthorizeOrderIdentifiers verifies the provisioner is authorized to issue a
|
||||
// certificate for the Identifiers provided in an Order.
|
||||
func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error {
|
||||
// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a
|
||||
// certificate for an ACME Order Identifier.
|
||||
func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error {
|
||||
|
||||
// identifier is allowed if no policy is configured
|
||||
if p.x509Policy == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// assuming only valid identifiers (IP or DNS) are provided
|
||||
var err error
|
||||
if ip := net.ParseIP(identifier); ip != nil {
|
||||
_, err = p.x509Policy.IsIPAllowed(ip)
|
||||
} else {
|
||||
_, err = p.x509Policy.IsDNSAllowed(identifier)
|
||||
switch identifier.Type {
|
||||
case IP:
|
||||
_, err = p.x509Policy.IsIPAllowed(net.ParseIP(identifier.Value))
|
||||
case DNS:
|
||||
_, err = p.x509Policy.IsDNSAllowed(identifier.Value)
|
||||
}
|
||||
|
||||
return err
|
||||
|
@ -142,10 +139,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeACME, p.Name, ""),
|
||||
newForceCNOption(p.ForceCN),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.x509Policy),
|
||||
}
|
||||
|
||||
|
@ -166,8 +163,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
|
|||
// revocation status. Just confirms that the provisioner that created the
|
||||
// certificate was configured to allow renewals.
|
||||
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
|
|
@ -91,6 +91,7 @@ func TestACME_Init(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestACME_AuthorizeRenew(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
type test struct {
|
||||
p *ACME
|
||||
cert *x509.Certificate
|
||||
|
@ -104,13 +105,16 @@ func TestACME_AuthorizeRenew(t *testing.T) {
|
|||
// disable renewal
|
||||
disable := true
|
||||
p.Claims = &Claims{DisableRenewal: &disable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
cert: &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
},
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()),
|
||||
err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -118,7 +122,10 @@ func TestACME_AuthorizeRenew(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
cert: &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -172,18 +179,18 @@ func TestACME_AuthorizeSign(t *testing.T) {
|
|||
for _, o := range opts {
|
||||
switch v := o.(type) {
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeACME))
|
||||
assert.Equals(t, v.Type, TypeACME)
|
||||
assert.Equals(t, v.Name, tc.p.GetName())
|
||||
assert.Equals(t, v.CredentialID, "")
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case *forceCNOption:
|
||||
assert.Equals(t, v.ForceCN, tc.p.ForceCN)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
|
||||
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
|
@ -264,11 +265,10 @@ type AWS struct {
|
|||
IIDRoots string `json:"iidRoots,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
claimer *Claimer
|
||||
config *awsConfig
|
||||
audiences Audiences
|
||||
x509Policy x509PolicyEngine
|
||||
sshHostPolicy *hostPolicyEngine
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
|
@ -402,15 +402,11 @@ func (p *AWS) Init(config Config) (err error) {
|
|||
case p.InstanceAge.Value() < 0:
|
||||
return errors.New("provisioner instanceAge cannot be negative")
|
||||
}
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add default config
|
||||
if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
|
||||
return err
|
||||
}
|
||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
|
||||
// validate IMDS versions
|
||||
if len(p.IMDSVersions) == 0 {
|
||||
|
@ -428,16 +424,18 @@ func (p *AWS) Init(config Config) (err error) {
|
|||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for host certificates
|
||||
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token and returns the sign options that
|
||||
|
@ -485,11 +483,11 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
commonNameValidator(payload.Claims.Subject),
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.x509Policy),
|
||||
), nil
|
||||
}
|
||||
|
@ -499,10 +497,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// revocation status. Just confirms that the provisioner that created the
|
||||
// certificate was configured to allow renewals.
|
||||
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
||||
// assertConfig initializes the config if it has not been initialized
|
||||
|
@ -561,7 +556,7 @@ func (p *AWS) readURL(url string) ([]byte, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("Request for metadata returned non-successful status code %d",
|
||||
return nil, fmt.Errorf("request for metadata returned non-successful status code %d",
|
||||
resp.StatusCode)
|
||||
}
|
||||
|
||||
|
@ -594,7 +589,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
|
|||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode)
|
||||
return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode)
|
||||
}
|
||||
token, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
|
@ -677,7 +672,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(payload.Audience, p.audiences.Sign) {
|
||||
if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) {
|
||||
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
|
@ -717,7 +712,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
|
|||
|
||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
|
||||
}
|
||||
claims, err := p.authorizeToken(token)
|
||||
|
@ -765,11 +760,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate user SignSSHOptions.
|
||||
sshCertOptionsValidator(defaults),
|
||||
// Set the validity bounds if not set.
|
||||
&sshDefaultDuration{p.claimer},
|
||||
&sshDefaultDuration{p.ctl.Claimer},
|
||||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
|
|
|
@ -677,18 +677,18 @@ func TestAWS_AuthorizeSign(t *testing.T) {
|
|||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeAWS))
|
||||
assert.Equals(t, v.Type, TypeAWS)
|
||||
assert.Equals(t, v.Name, tt.aws.GetName())
|
||||
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
|
||||
assert.Len(t, 2, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration())
|
||||
assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration())
|
||||
case commonNameValidator:
|
||||
assert.Equals(t, string(v), tt.args.cn)
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration())
|
||||
case ipAddressesValidator:
|
||||
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
|
||||
case emailAddressesValidator:
|
||||
|
@ -728,7 +728,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
|||
// disable sshCA
|
||||
disable := false
|
||||
p3.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
|
||||
|
@ -749,7 +749,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
|||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SignSSHOptions{
|
||||
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
|
@ -826,6 +826,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAWS_AuthorizeRenew(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
p1, err := generateAWS()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateAWS()
|
||||
|
@ -834,7 +835,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
|
|||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
|
@ -847,8 +848,14 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
|
|||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
{"ok", p1, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -13,10 +13,13 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
|
||||
|
@ -30,7 +33,7 @@ const azureDefaultAudience = "https://management.azure.com/"
|
|||
|
||||
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
|
||||
// Using case insensitive as resourceGroups appears as resourcegroups.
|
||||
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`)
|
||||
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
|
||||
|
||||
type azureConfig struct {
|
||||
oidcDiscoveryURL string
|
||||
|
@ -96,12 +99,12 @@ type Azure struct {
|
|||
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
claimer *Claimer
|
||||
config *azureConfig
|
||||
oidcConfig openIDConfiguration
|
||||
keyStore *keyStore
|
||||
x509Policy x509PolicyEngine
|
||||
sshHostPolicy *hostPolicyEngine
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier.
|
||||
|
@ -205,37 +208,34 @@ func (p *Azure) Init(config Config) (err error) {
|
|||
case p.Audience == "": // use default audience
|
||||
p.Audience = azureDefaultAudience
|
||||
}
|
||||
|
||||
// Initialize config
|
||||
p.assertConfig()
|
||||
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode and validate openid-configuration endpoint
|
||||
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
|
||||
return err
|
||||
if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
|
||||
return
|
||||
}
|
||||
if err := p.oidcConfig.Validate(); err != nil {
|
||||
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
|
||||
}
|
||||
// Get JWK key set
|
||||
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
|
||||
return err
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for host certificates
|
||||
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// authorizeToken returns the claims, name, group, subscription, identityObjectID, error.
|
||||
|
@ -275,11 +275,14 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str
|
|||
}
|
||||
|
||||
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
|
||||
if len(re) != 4 {
|
||||
if len(re) != 5 {
|
||||
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
|
||||
}
|
||||
|
||||
var subscription, group, name string
|
||||
identityObjectID := claims.ObjectID
|
||||
subscription, group, name := re[1], re[2], re[3]
|
||||
subscription, group, name = re[1], re[2], re[4]
|
||||
|
||||
return &claims, name, group, subscription, identityObjectID, nil
|
||||
}
|
||||
|
||||
|
@ -367,10 +370,10 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.x509Policy),
|
||||
), nil
|
||||
}
|
||||
|
@ -380,15 +383,12 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// revocation status. Just confirms that the provisioner that created the
|
||||
// certificate was configured to allow renewals.
|
||||
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
|
||||
}
|
||||
|
||||
|
@ -433,11 +433,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
// Validate user SignSSHOptions.
|
||||
sshCertOptionsValidator(defaults),
|
||||
// Set the validity bounds if not set.
|
||||
&sshDefaultDuration{p.claimer},
|
||||
&sshDefaultDuration{p.ctl.Claimer},
|
||||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
|
|
|
@ -95,7 +95,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
@ -237,7 +237,7 @@ func TestAzure_authorizeToken(t *testing.T) {
|
|||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
|
||||
assert.FatalError(t, err)
|
||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), jwk)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
|
@ -252,7 +252,7 @@ func TestAzure_authorizeToken(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
|
@ -267,7 +267,7 @@ func TestAzure_authorizeToken(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||
"foo", "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
"foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
|
@ -321,7 +321,7 @@ func TestAzure_authorizeToken(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
defer srv.Close()
|
||||
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), &p.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
|
@ -437,28 +437,28 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
|
||||
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
|
||||
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
|
||||
assert.FatalError(t, err)
|
||||
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine",
|
||||
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
|
||||
time.Now(), badKey)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
|
@ -506,18 +506,18 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeAzure))
|
||||
assert.Equals(t, v.Type, TypeAzure)
|
||||
assert.Equals(t, v.Name, tt.azure.GetName())
|
||||
assert.Equals(t, v.CredentialID, tt.azure.TenantID)
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration())
|
||||
assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration())
|
||||
case commonNameValidator:
|
||||
assert.Equals(t, string(v), "virtualMachine")
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration())
|
||||
case ipAddressesValidator:
|
||||
assert.Equals(t, v, nil)
|
||||
case emailAddressesValidator:
|
||||
|
@ -538,6 +538,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAzure_AuthorizeRenew(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
p1, err := generateAzure()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateAzure()
|
||||
|
@ -546,7 +547,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
|
|||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
|
@ -559,8 +560,14 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
|
|||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
{"ok", p1, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -597,7 +604,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
|
|||
// disable sshCA
|
||||
disable := false
|
||||
p3.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := p1.GetIdentityToken("subject", "caURL")
|
||||
|
@ -618,7 +625,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
|
|||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SignSSHOptions{
|
||||
CertType: "host", Principals: []string{"virtualMachine"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
|
|
|
@ -13,7 +13,7 @@ type Claims struct {
|
|||
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
|
||||
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
|
||||
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
|
||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||
|
||||
// SSH CA properties
|
||||
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
|
||||
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
|
||||
|
@ -22,6 +22,10 @@ type Claims struct {
|
|||
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
|
||||
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
|
||||
EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
|
||||
|
||||
// Renewal properties
|
||||
DisableRenewal *bool `json:"disableRenewal,omitempty"`
|
||||
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
|
||||
}
|
||||
|
||||
// Claimer is the type that controls claims. It provides an interface around the
|
||||
|
@ -40,12 +44,13 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
|
|||
// Claims returns the merge of the inner and global claims.
|
||||
func (c *Claimer) Claims() Claims {
|
||||
disableRenewal := c.IsDisableRenewal()
|
||||
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
|
||||
enableSSHCA := c.IsSSHCAEnabled()
|
||||
|
||||
return Claims{
|
||||
MinTLSDur: &Duration{c.MinTLSCertDuration()},
|
||||
MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
|
||||
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
|
||||
DisableRenewal: &disableRenewal,
|
||||
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
|
||||
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
|
||||
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
|
||||
|
@ -53,6 +58,8 @@ func (c *Claimer) Claims() Claims {
|
|||
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
|
||||
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
|
||||
EnableSSHCA: &enableSSHCA,
|
||||
DisableRenewal: &disableRenewal,
|
||||
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool {
|
|||
return *c.claims.DisableRenewal
|
||||
}
|
||||
|
||||
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
|
||||
// certificate is expired. If the property is not set within the provisioner
|
||||
// then the global value from the authority configuration will be used.
|
||||
func (c *Claimer) AllowRenewAfterExpiry() bool {
|
||||
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
|
||||
return *c.global.AllowRenewAfterExpiry
|
||||
}
|
||||
return *c.claims.AllowRenewAfterExpiry
|
||||
}
|
||||
|
||||
// DefaultSSHCertDuration returns the default SSH certificate duration for the
|
||||
// given certificate type.
|
||||
func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) {
|
||||
|
|
|
@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
|
|||
// proper id to load the provisioner.
|
||||
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
|
||||
for _, e := range cert.Extensions {
|
||||
if e.Id.Equal(stepOIDProvisioner) {
|
||||
var provisioner stepProvisionerASN1
|
||||
if e.Id.Equal(StepOIDProvisioner) {
|
||||
var provisioner extensionASN1
|
||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
|
|
@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCollection_LoadByCertificate(t *testing.T) {
|
||||
mustExtension := func(typ Type, name, credentialID string) pkix.Extension {
|
||||
e := Extension{
|
||||
Type: typ, Name: name, CredentialID: credentialID,
|
||||
}
|
||||
ext, err := e.ToExtension()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return ext
|
||||
}
|
||||
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
|
@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) {
|
|||
byName.Store(p2.GetName(), p2)
|
||||
byName.Store(p3.GetName(), p3)
|
||||
|
||||
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
|
||||
assert.FatalError(t, err)
|
||||
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
|
||||
assert.FatalError(t, err)
|
||||
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
|
||||
assert.FatalError(t, err)
|
||||
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
|
||||
assert.FatalError(t, err)
|
||||
|
||||
ok1Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok1Ext},
|
||||
Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)},
|
||||
}
|
||||
ok2Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok2Ext},
|
||||
Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)},
|
||||
}
|
||||
ok3Cert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{ok3Ext},
|
||||
Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")},
|
||||
}
|
||||
notFoundCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{notFoundExt},
|
||||
Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")},
|
||||
}
|
||||
badCert := &x509.Certificate{
|
||||
Extensions: []pkix.Extension{
|
||||
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
||||
{Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")},
|
||||
},
|
||||
}
|
||||
|
||||
|
|
194
authority/provisioner/controller.go
Normal file
194
authority/provisioner/controller.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Controller wraps a provisioner with other attributes useful in callback
|
||||
// functions.
|
||||
type Controller struct {
|
||||
Interface
|
||||
Audiences *Audiences
|
||||
Claimer *Claimer
|
||||
IdentityFunc GetIdentityFunc
|
||||
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||
}
|
||||
|
||||
// NewController initializes a new provisioner controller.
|
||||
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
|
||||
claimer, err := NewClaimer(claims, config.Claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Controller{
|
||||
Interface: p,
|
||||
Audiences: &config.Audiences,
|
||||
Claimer: claimer,
|
||||
IdentityFunc: config.GetIdentityFunc,
|
||||
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
|
||||
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetIdentity returns the identity for a given email.
|
||||
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
|
||||
if c.IdentityFunc != nil {
|
||||
return c.IdentityFunc(ctx, c.Interface, email)
|
||||
}
|
||||
return DefaultIdentityFunc(ctx, c.Interface, email)
|
||||
}
|
||||
|
||||
// AuthorizeRenew returns nil if the given cert can be renewed, returns an error
|
||||
// otherwise.
|
||||
func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if c.AuthorizeRenewFunc != nil {
|
||||
return c.AuthorizeRenewFunc(ctx, c, cert)
|
||||
}
|
||||
return DefaultAuthorizeRenew(ctx, c, cert)
|
||||
}
|
||||
|
||||
// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an
|
||||
// error otherwise.
|
||||
func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error {
|
||||
if c.AuthorizeSSHRenewFunc != nil {
|
||||
return c.AuthorizeSSHRenewFunc(ctx, c, cert)
|
||||
}
|
||||
return DefaultAuthorizeSSHRenew(ctx, c, cert)
|
||||
}
|
||||
|
||||
// Identity is the type representing an externally supplied identity that is used
|
||||
// by provisioners to populate certificate fields.
|
||||
type Identity struct {
|
||||
Usernames []string `json:"usernames"`
|
||||
Permissions `json:"permissions"`
|
||||
}
|
||||
|
||||
// GetIdentityFunc is a function that returns an identity.
|
||||
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
|
||||
|
||||
// AuthorizeRenewFunc is a function that returns nil if the renewal of a
|
||||
// certificate is enabled.
|
||||
type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error
|
||||
|
||||
// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the
|
||||
// given SSH certificate is enabled.
|
||||
type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error
|
||||
|
||||
// DefaultIdentityFunc return a default identity depending on the provisioner
|
||||
// type. For OIDC email is always present and the usernames might
|
||||
// contain empty strings.
|
||||
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
switch k := p.(type) {
|
||||
case *OIDC:
|
||||
// OIDC principals would be:
|
||||
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
|
||||
// 2. Sanitized local.
|
||||
// 3. Raw local (if different).
|
||||
// 4. Email address.
|
||||
name := SanitizeSSHUserPrincipal(email)
|
||||
if !sshUserRegex.MatchString(name) {
|
||||
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
|
||||
}
|
||||
usernames := []string{name}
|
||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||
usernames = append(usernames, email[:i])
|
||||
}
|
||||
usernames = append(usernames, email)
|
||||
return &Identity{
|
||||
Usernames: SanitizeStringSlices(usernames),
|
||||
}, nil
|
||||
default:
|
||||
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It
|
||||
// will return an error if the provisioner has the renewal disabled, if the
|
||||
// certificate is not yet valid or if the certificate is expired and renew after
|
||||
// expiry is disabled.
|
||||
func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
if p.Claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
|
||||
}
|
||||
|
||||
now := time.Now().Truncate(time.Second)
|
||||
if now.Before(cert.NotBefore) {
|
||||
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
|
||||
}
|
||||
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
|
||||
return errs.Unauthorized("certificate has expired")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It
|
||||
// will return an error if the provisioner has the renewal disabled, if the
|
||||
// certificate is not yet valid or if the certificate is expired and renew after
|
||||
// expiry is disabled.
|
||||
func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
if p.Claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
|
||||
}
|
||||
|
||||
unixNow := time.Now().Unix()
|
||||
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
|
||||
return errs.Unauthorized("certificate is not yet valid")
|
||||
}
|
||||
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
|
||||
return errs.Unauthorized("certificate has expired")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
|
||||
|
||||
// SanitizeStringSlices removes duplicated an empty strings.
|
||||
func SanitizeStringSlices(original []string) []string {
|
||||
output := []string{}
|
||||
seen := make(map[string]struct{})
|
||||
for _, entry := range original {
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if _, value := seen[entry]; !value {
|
||||
seen[entry] = struct{}{}
|
||||
output = append(output, entry)
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
// SanitizeSSHUserPrincipal grabs an email or a string with the format
|
||||
// local@domain and returns a sanitized version of the local, valid to be used
|
||||
// as a user name. If the email starts with a letter between a and z, the
|
||||
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
|
||||
func SanitizeSSHUserPrincipal(email string) string {
|
||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||
email = email[:i]
|
||||
}
|
||||
return strings.Map(func(r rune) rune {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
return r
|
||||
case r >= '0' && r <= '9':
|
||||
return r
|
||||
case r == '-':
|
||||
return '-'
|
||||
case r == '.': // drop dots
|
||||
return -1
|
||||
default:
|
||||
return '_'
|
||||
}
|
||||
}, strings.ToLower(email))
|
||||
}
|
391
authority/provisioner/controller_test.go
Normal file
391
authority/provisioner/controller_test.go
Normal file
|
@ -0,0 +1,391 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var trueValue = true
|
||||
|
||||
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
|
||||
t.Helper()
|
||||
c, err := NewClaimer(claims, global)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return c
|
||||
}
|
||||
func mustDuration(t *testing.T, s string) *Duration {
|
||||
t.Helper()
|
||||
d, err := NewDuration(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func TestNewController(t *testing.T) {
|
||||
type args struct {
|
||||
p Interface
|
||||
claims *Claims
|
||||
config Config
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Controller
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{&JWK{}, nil, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}}, &Controller{
|
||||
Interface: &JWK{},
|
||||
Audiences: &testAudiences,
|
||||
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||
}, false},
|
||||
{"ok with claims", args{&JWK{}, &Claims{
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}}, &Controller{
|
||||
Interface: &JWK{},
|
||||
Audiences: &testAudiences,
|
||||
Claimer: mustClaimer(t, &Claims{
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
}, globalProvisionerClaims),
|
||||
}, false},
|
||||
{"fail claimer", args{&JWK{}, &Claims{
|
||||
MinTLSDur: mustDuration(t, "24h"),
|
||||
MaxTLSDur: mustDuration(t, "2h"),
|
||||
}, Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewController() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestController_GetIdentity(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
type fields struct {
|
||||
Interface Interface
|
||||
IdentityFunc GetIdentityFunc
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
email string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *Identity
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{
|
||||
Usernames: []string{"jane", "jane@doe.org"},
|
||||
}, false},
|
||||
{"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
return &Identity{Usernames: []string{"jane"}}, nil
|
||||
}}, args{ctx, "jane@doe.org"}, &Identity{
|
||||
Usernames: []string{"jane"},
|
||||
}, false},
|
||||
{"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true},
|
||||
{"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
return nil, fmt.Errorf("an error")
|
||||
}}, args{ctx, "jane@doe.org"}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Controller{
|
||||
Interface: tt.fields.Interface,
|
||||
IdentityFunc: tt.fields.IdentityFunc,
|
||||
}
|
||||
got, err := c.GetIdentity(tt.args.ctx, tt.args.email)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestController_AuthorizeRenew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now().Truncate(time.Second)
|
||||
type fields struct {
|
||||
Interface Interface
|
||||
Claimer *Claimer
|
||||
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, false},
|
||||
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, true},
|
||||
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now.Add(time.Hour),
|
||||
NotAfter: now.Add(2 * time.Hour),
|
||||
}}, true},
|
||||
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, true},
|
||||
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
|
||||
return fmt.Errorf("an error")
|
||||
}}, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Controller{
|
||||
Interface: tt.fields.Interface,
|
||||
Claimer: tt.fields.Claimer,
|
||||
AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
|
||||
}
|
||||
if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestController_AuthorizeSSHRenew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
type fields struct {
|
||||
Interface Interface
|
||||
Claimer *Claimer
|
||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
cert *ssh.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
return nil
|
||||
}}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, false},
|
||||
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, true},
|
||||
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
|
||||
}}, true},
|
||||
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, true},
|
||||
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
|
||||
return fmt.Errorf("an error")
|
||||
}}, args{ctx, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Controller{
|
||||
Interface: tt.fields.Interface,
|
||||
Claimer: tt.fields.Claimer,
|
||||
AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
|
||||
}
|
||||
if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAuthorizeRenew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now().Truncate(time.Second)
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
p *Controller
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"ok renew after expiry", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, false},
|
||||
{"fail disabled", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, true},
|
||||
{"fail not yet valid", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now.Add(time.Hour),
|
||||
NotAfter: now.Add(2 * time.Hour),
|
||||
}}, true},
|
||||
{"fail expired", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &x509.Certificate{
|
||||
NotBefore: now.Add(-time.Hour),
|
||||
NotAfter: now.Add(-time.Minute),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
now := time.Now()
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
p *Controller
|
||||
cert *ssh.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, false},
|
||||
{"ok renew after expiry", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, false},
|
||||
{"fail disabled", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Unix()),
|
||||
ValidBefore: uint64(now.Add(time.Hour).Unix()),
|
||||
}}, true},
|
||||
{"fail not yet valid", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
|
||||
}}, true},
|
||||
{"fail expired", args{ctx, &Controller{
|
||||
Interface: &JWK{},
|
||||
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
|
||||
}, &ssh.Certificate{
|
||||
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
|
||||
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
|
||||
t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
73
authority/provisioner/extension.go
Normal file
73
authority/provisioner/extension.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
)
|
||||
|
||||
var (
|
||||
// StepOIDRoot is the root OID for smallstep.
|
||||
StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||
|
||||
// StepOIDProvisioner is the OID for the provisioner extension.
|
||||
StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...)
|
||||
)
|
||||
|
||||
// Extension is the Go representation of the provisioner extension.
|
||||
type Extension struct {
|
||||
Type Type
|
||||
Name string
|
||||
CredentialID string
|
||||
KeyValuePairs []string
|
||||
}
|
||||
|
||||
type extensionASN1 struct {
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||
}
|
||||
|
||||
// Marshal marshals the extension using encoding/asn1.
|
||||
func (e *Extension) Marshal() ([]byte, error) {
|
||||
return asn1.Marshal(extensionASN1{
|
||||
Type: int(e.Type),
|
||||
Name: []byte(e.Name),
|
||||
CredentialID: []byte(e.CredentialID),
|
||||
KeyValuePairs: e.KeyValuePairs,
|
||||
})
|
||||
}
|
||||
|
||||
// ToExtension returns the pkix.Extension representation of the provisioner
|
||||
// extension.
|
||||
func (e *Extension) ToExtension() (pkix.Extension, error) {
|
||||
b, err := e.Marshal()
|
||||
if err != nil {
|
||||
return pkix.Extension{}, err
|
||||
}
|
||||
return pkix.Extension{
|
||||
Id: StepOIDProvisioner,
|
||||
Value: b,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetProvisionerExtension goes through all the certificate extensions and
|
||||
// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1).
|
||||
func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) {
|
||||
for _, e := range cert.Extensions {
|
||||
if e.Id.Equal(StepOIDProvisioner) {
|
||||
var provisioner extensionASN1
|
||||
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return &Extension{
|
||||
Type: Type(provisioner.Type),
|
||||
Name: string(provisioner.Name),
|
||||
CredentialID: string(provisioner.CredentialID),
|
||||
KeyValuePairs: provisioner.KeyValuePairs,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
158
authority/provisioner/extension_test.go
Normal file
158
authority/provisioner/extension_test.go
Normal file
|
@ -0,0 +1,158 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"go.step.sm/crypto/pemutil"
|
||||
)
|
||||
|
||||
func TestExtension_Marshal(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Name string
|
||||
CredentialID string
|
||||
KeyValuePairs []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{
|
||||
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||
0x44,
|
||||
}, false},
|
||||
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{
|
||||
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
|
||||
0x13, 0x03, 0x62, 0x61, 0x72,
|
||||
}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &Extension{
|
||||
Type: tt.fields.Type,
|
||||
Name: tt.fields.Name,
|
||||
CredentialID: tt.fields.CredentialID,
|
||||
KeyValuePairs: tt.fields.KeyValuePairs,
|
||||
}
|
||||
got, err := e.Marshal()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtension_ToExtension(t *testing.T) {
|
||||
type fields struct {
|
||||
Type Type
|
||||
Name string
|
||||
CredentialID string
|
||||
KeyValuePairs []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want pkix.Extension
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{
|
||||
Id: StepOIDProvisioner,
|
||||
Value: []byte{
|
||||
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||
0x44,
|
||||
},
|
||||
}, false},
|
||||
{"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{
|
||||
Id: StepOIDProvisioner,
|
||||
Value: []byte{
|
||||
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||
0x44,
|
||||
},
|
||||
}, false},
|
||||
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{
|
||||
Id: StepOIDProvisioner,
|
||||
Value: []byte{
|
||||
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
|
||||
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
|
||||
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
|
||||
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
|
||||
0x13, 0x03, 0x62, 0x61, 0x72,
|
||||
},
|
||||
}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &Extension{
|
||||
Type: tt.fields.Type,
|
||||
Name: tt.fields.Name,
|
||||
CredentialID: tt.fields.CredentialID,
|
||||
KeyValuePairs: tt.fields.KeyValuePairs,
|
||||
}
|
||||
got, err := e.ToExtension()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProvisionerExtension(t *testing.T) {
|
||||
mustCertificate := func(fn string) *x509.Certificate {
|
||||
cert, err := pemutil.ReadCertificate(fn)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return cert
|
||||
}
|
||||
|
||||
type args struct {
|
||||
cert *x509.Certificate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Extension
|
||||
want1 bool
|
||||
}{
|
||||
{"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{
|
||||
Type: TypeJWK,
|
||||
Name: "mariano@smallstep.com",
|
||||
CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ",
|
||||
}, true},
|
||||
{"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false},
|
||||
{"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1 := GetProvisionerExtension(tt.args.cert)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
|
@ -88,12 +89,11 @@ type GCP struct {
|
|||
InstanceAge Duration `json:"instanceAge,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
claimer *Claimer
|
||||
config *gcpConfig
|
||||
keyStore *keyStore
|
||||
audiences Audiences
|
||||
x509Policy x509PolicyEngine
|
||||
sshHostPolicy *hostPolicyEngine
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier. The name should uniquely
|
||||
|
@ -196,8 +196,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
|
|||
}
|
||||
|
||||
// Init validates and initializes the GCP provisioner.
|
||||
func (p *GCP) Init(config Config) error {
|
||||
var err error
|
||||
func (p *GCP) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
|
@ -206,30 +205,28 @@ func (p *GCP) Init(config Config) error {
|
|||
case p.InstanceAge.Value() < 0:
|
||||
return errors.New("provisioner instanceAge cannot be negative")
|
||||
}
|
||||
|
||||
// Initialize config
|
||||
p.assertConfig()
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize key store
|
||||
p.keyStore, err = newKeyStore(p.config.CertsURL)
|
||||
if err != nil {
|
||||
return err
|
||||
if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for host certificates
|
||||
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
return nil
|
||||
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token and returns the sign options that
|
||||
|
@ -281,20 +278,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.x509Policy),
|
||||
), nil
|
||||
}
|
||||
|
||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
||||
// assertConfig initializes the config if it has not been initialized.
|
||||
|
@ -341,7 +335,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
}
|
||||
|
||||
// validate audiences with the defaults
|
||||
if !matchesAudience(claims.Audience, p.audiences.Sign) {
|
||||
if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) {
|
||||
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
|
||||
}
|
||||
|
||||
|
@ -396,7 +390,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
|
|||
|
||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
|
||||
}
|
||||
claims, err := p.authorizeToken(token)
|
||||
|
@ -444,11 +438,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
// Validate user SignSSHOptions.
|
||||
sshCertOptionsValidator(defaults),
|
||||
// Set the validity bounds if not set.
|
||||
&sshDefaultDuration{p.claimer},
|
||||
&sshDefaultDuration{p.ctl.Claimer},
|
||||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
|
|
|
@ -549,18 +549,18 @@ func TestGCP_AuthorizeSign(t *testing.T) {
|
|||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeGCP))
|
||||
assert.Equals(t, v.Type, TypeGCP)
|
||||
assert.Equals(t, v.Name, tt.gcp.GetName())
|
||||
assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0])
|
||||
assert.Len(t, 4, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration())
|
||||
assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration())
|
||||
case commonNameSliceValidator:
|
||||
assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration())
|
||||
case ipAddressesValidator:
|
||||
assert.Equals(t, v, nil)
|
||||
case emailAddressesValidator:
|
||||
|
@ -597,7 +597,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
|||
// disable sshCA
|
||||
disable := false
|
||||
p3.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||
p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
t1, err := generateGCPToken(p1.ServiceAccounts[0],
|
||||
|
@ -624,7 +624,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
|||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||
expectedHostOptions := &SignSSHOptions{
|
||||
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
|
||||
|
@ -700,6 +700,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGCP_AuthorizeRenew(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
p1, err := generateGCP()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateGCP()
|
||||
|
@ -708,7 +709,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
|
|||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
|
@ -721,8 +722,14 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
|
|||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
{"ok", p1, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusOK, false},
|
||||
{"fail/renewal-disabled", p2, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -7,10 +7,13 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// jwtPayload extends jwt.Claims with step attributes.
|
||||
|
@ -35,11 +38,10 @@ type JWK struct {
|
|||
EncryptedKey string `json:"encryptedKey,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
claimer *Claimer
|
||||
audiences Audiences
|
||||
x509Policy x509PolicyEngine
|
||||
sshHostPolicy *hostPolicyEngine
|
||||
sshUserPolicy *userPolicyEngine
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
sshUserPolicy policy.UserPolicy
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier. The name and credential id
|
||||
|
@ -101,28 +103,23 @@ func (p *JWK) Init(config Config) (err error) {
|
|||
return errors.New("provisioner key cannot be empty")
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for user certificates
|
||||
if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for host certificates
|
||||
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.audiences = config.Audiences
|
||||
return err
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// authorizeToken performs common jwt authorization actions and returns the
|
||||
|
@ -164,14 +161,14 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
|
|||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
||||
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
|
||||
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke)
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
|
||||
}
|
||||
|
@ -198,12 +195,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||
// validators
|
||||
commonNameValidator(claims.Subject),
|
||||
defaultPublicKeyValidator{},
|
||||
defaultSANsValidator(claims.SANs),
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.x509Policy),
|
||||
}, nil
|
||||
}
|
||||
|
@ -213,12 +210,13 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
// revocation status. Just confirms that the provisioner that created the
|
||||
// certificate was configured to allow renewals.
|
||||
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName())
|
||||
}
|
||||
// if p.claimer.IsDisableRenewal() {
|
||||
// return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName())
|
||||
// }
|
||||
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew)
|
||||
//return p.authorizeRenew(cert)
|
||||
return nil
|
||||
// return nil
|
||||
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
||||
// func (p *JWK) authorizeRenew(cert *x509.Certificate) error {
|
||||
|
@ -231,10 +229,10 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
|
|||
|
||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
|
||||
}
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
|
||||
}
|
||||
|
@ -291,11 +289,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
|
||||
return append(signOptions,
|
||||
// Set the validity bounds if not set.
|
||||
&sshDefaultDuration{p.claimer},
|
||||
&sshDefaultDuration{p.ctl.Claimer},
|
||||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
|
@ -305,7 +303,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
|
||||
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
|
||||
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.audiences.SSHRevoke)
|
||||
_, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
|
||||
// TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke)
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
|
||||
}
|
||||
|
|
|
@ -76,13 +76,13 @@ func TestJWK_Init(t *testing.T) {
|
|||
},
|
||||
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
|
||||
err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences},
|
||||
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -300,18 +300,18 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
|||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeJWK))
|
||||
assert.Equals(t, v.Type, TypeJWK)
|
||||
assert.Equals(t, v.Name, tt.prov.GetName())
|
||||
assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID)
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
||||
assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
|
||||
case commonNameValidator:
|
||||
assert.Equals(t, string(v), "subject")
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
|
||||
case defaultSANsValidator:
|
||||
assert.Equals(t, []string(v), tt.sans)
|
||||
case *x509NamePolicyValidator:
|
||||
|
@ -327,6 +327,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestJWK_AuthorizeRenew(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
p1, err := generateJWK()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateJWK()
|
||||
|
@ -335,7 +336,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
|
|||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
|
@ -348,8 +349,14 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
|
|||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
{"ok", p1, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -375,7 +382,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
|
|||
// disable sshCA
|
||||
disable := false
|
||||
p2.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
jwk, err := decryptJSONWebKey(p1.EncryptedKey)
|
||||
|
@ -404,8 +411,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
|
|||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||
expectedUserOptions := &SignSSHOptions{
|
||||
CertType: "user", Principals: []string{"name"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||
|
@ -487,8 +494,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
|
|||
signer, err := generateJSONWebKey()
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||
expectedUserOptions := &SignSSHOptions{
|
||||
CertType: "user", Principals: []string{"name"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||
|
|
|
@ -10,11 +10,14 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/pemutil"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// NOTE: There can be at most one kubernetes service account provisioner configured
|
||||
|
@ -48,13 +51,12 @@ type K8sSA struct {
|
|||
PubKeys []byte `json:"publicKeys,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
claimer *Claimer
|
||||
audiences Audiences
|
||||
//kauthn kauthn.AuthenticationV1Interface
|
||||
pubKeys []interface{}
|
||||
x509Policy x509PolicyEngine
|
||||
sshHostPolicy *hostPolicyEngine
|
||||
sshUserPolicy *userPolicyEngine
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
sshUserPolicy policy.UserPolicy
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier. The name and credential id
|
||||
|
@ -142,28 +144,23 @@ func (p *K8sSA) Init(config Config) (err error) {
|
|||
p.kauthn = k8s.AuthenticationV1()
|
||||
*/
|
||||
|
||||
// Update claims with global ones
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for user certificates
|
||||
if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for host certificates
|
||||
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.audiences = config.Audiences
|
||||
return err
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// authorizeToken performs common jwt authorization actions and returns the
|
||||
|
@ -230,13 +227,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
|
|||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
||||
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
|
||||
}
|
||||
|
@ -259,28 +256,25 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
|
||||
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()),
|
||||
profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
|
||||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.x509Policy),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
||||
// AuthorizeSSHSign validates an request for an SSH certificate.
|
||||
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
|
||||
}
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
|
||||
}
|
||||
|
@ -302,11 +296,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
|
|||
// Require type, key-id and principals in the SignSSHOptions.
|
||||
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
|
||||
// Set the validity bounds if not set.
|
||||
&sshDefaultDuration{p.claimer},
|
||||
&sshDefaultDuration{p.ctl.Claimer},
|
||||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
|
|
|
@ -179,6 +179,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
type test struct {
|
||||
p *K8sSA
|
||||
cert *x509.Certificate
|
||||
|
@ -192,13 +193,16 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
|||
// disable renewal
|
||||
disable := true
|
||||
p.Claims = &Claims{DisableRenewal: &disable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
cert: &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
},
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()),
|
||||
err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -206,7 +210,10 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
|
|||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
cert: &x509.Certificate{},
|
||||
cert: &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
@ -275,16 +282,16 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
|
|||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeK8sSA))
|
||||
assert.Equals(t, v.Type, TypeK8sSA)
|
||||
assert.Equals(t, v.Name, tc.p.GetName())
|
||||
assert.Equals(t, v.CredentialID, "")
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration())
|
||||
assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
|
@ -313,7 +320,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
|||
// disable sshCA
|
||||
disable := false
|
||||
p.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
|
@ -365,11 +372,11 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
|
|||
case *sshCertOptionsRequireValidator:
|
||||
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||
case *sshDefaultPublicKeyValidator:
|
||||
case *sshCertDefaultValidator:
|
||||
case *sshDefaultDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||
case *sshNamePolicyValidator:
|
||||
assert.Equals(t, nil, v.userPolicyEngine)
|
||||
assert.Equals(t, nil, v.hostPolicyEngine)
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
nebula "github.com/slackhq/nebula/cert"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
|
@ -40,16 +41,15 @@ type Nebula struct {
|
|||
Roots []byte `json:"roots"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
claimer *Claimer
|
||||
caPool *nebula.NebulaCAPool
|
||||
audiences Audiences
|
||||
x509Policy x509PolicyEngine
|
||||
sshHostPolicy *hostPolicyEngine
|
||||
sshUserPolicy *userPolicyEngine
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
sshUserPolicy policy.UserPolicy
|
||||
}
|
||||
|
||||
// Init verifies and initializes the Nebula provisioner.
|
||||
func (p *Nebula) Init(config Config) error {
|
||||
func (p *Nebula) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
|
@ -59,34 +59,29 @@ func (p *Nebula) Init(config Config) error {
|
|||
return errors.New("provisioner root(s) cannot be empty")
|
||||
}
|
||||
|
||||
var err error
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots)
|
||||
if err != nil {
|
||||
return errs.InternalServer("failed to create ca pool: %v", err)
|
||||
}
|
||||
|
||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for user certificates
|
||||
if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for host certificates
|
||||
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// GetID returns the provisioner id.
|
||||
|
@ -138,7 +133,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) {
|
|||
|
||||
// AuthorizeSign returns the list of SignOption for a Sign request.
|
||||
func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
crt, claims, err := p.authorizeToken(token, p.audiences.Sign)
|
||||
crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -172,7 +167,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeNebula, p.Name, ""),
|
||||
profileLimitDuration{
|
||||
def: p.claimer.DefaultTLSCertDuration(),
|
||||
def: p.ctl.Claimer.DefaultTLSCertDuration(),
|
||||
notBefore: crt.Details.NotBefore,
|
||||
notAfter: crt.Details.NotAfter,
|
||||
},
|
||||
|
@ -183,7 +178,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
IPs: crt.Details.Ips,
|
||||
},
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.x509Policy),
|
||||
}, nil
|
||||
}
|
||||
|
@ -191,11 +186,11 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
|
|||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
// Currently the Nebula provisioner only grants host SSH certificates.
|
||||
func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
|
||||
}
|
||||
|
||||
crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -273,11 +268,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
|
|||
return append(signOptions,
|
||||
templateOptions,
|
||||
// Checks the validity bounds, and set the validity if has not been set.
|
||||
&sshLimitDuration{p.claimer, crt.Details.NotAfter},
|
||||
&sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
|
||||
// Validate public key.
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
|
@ -287,23 +282,20 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
|
|||
|
||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
return p.ctl.AuthorizeRenew(ctx, crt)
|
||||
}
|
||||
|
||||
// AuthorizeRevoke returns an error if the token is not valid.
|
||||
func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
return p.validateToken(token, p.audiences.Revoke)
|
||||
return p.validateToken(token, p.ctl.Audiences.Revoke)
|
||||
}
|
||||
|
||||
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
|
||||
func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
|
||||
}
|
||||
if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil {
|
||||
if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) {
|
|||
func TestNebula_GetTokenID(t *testing.T) {
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
|
||||
t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
|
||||
t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
|
||||
_, claims, err := parseToken(t1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) {
|
|||
ctx := context.TODO()
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
|
||||
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv)
|
||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
|
||||
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv)
|
||||
|
||||
pBadOptions, _, _ := mustNebulaProvisioner(t)
|
||||
pBadOptions.caPool = p.caPool
|
||||
|
@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
|
|||
// Ok provisioner
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||
CertType: "host",
|
||||
KeyID: "test.lan",
|
||||
Principals: []string{"test.lan", "10.1.0.1"},
|
||||
}, crt, priv)
|
||||
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv)
|
||||
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv)
|
||||
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||
ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)),
|
||||
ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)),
|
||||
}, crt, priv)
|
||||
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||
CertType: "user",
|
||||
}, crt, priv)
|
||||
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
|
||||
CertType: "host",
|
||||
KeyID: "test.lan",
|
||||
Principals: []string{"test.lan", "10.1.0.1", "foo.bar"},
|
||||
|
@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
|
|||
|
||||
func TestNebula_AuthorizeRenew(t *testing.T) {
|
||||
ctx := context.TODO()
|
||||
now := time.Now().Truncate(time.Second)
|
||||
|
||||
// Ok provisioner
|
||||
p, _, _ := mustNebulaProvisioner(t)
|
||||
|
||||
|
@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) {
|
|||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p, args{ctx, &x509.Certificate{}}, false},
|
||||
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true},
|
||||
{"ok", p, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, false},
|
||||
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -584,12 +592,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) {
|
|||
// Ok provisioner
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv)
|
||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
|
||||
|
||||
// Fail different CA
|
||||
nc, signer := mustNebulaCA(t)
|
||||
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
|
||||
failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv)
|
||||
failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
|
@ -618,12 +626,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
|
|||
// Ok provisioner
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv)
|
||||
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
|
||||
|
||||
// Fail different CA
|
||||
nc, signer := mustNebulaCA(t)
|
||||
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
|
||||
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv)
|
||||
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
|
||||
|
||||
// Provisioner with SSH disabled
|
||||
var bFalse bool
|
||||
|
@ -657,7 +665,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
|
|||
func TestNebula_AuthorizeSSHRenew(t *testing.T) {
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv)
|
||||
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
|
@ -689,7 +697,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) {
|
|||
func TestNebula_AuthorizeSSHRekey(t *testing.T) {
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv)
|
||||
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
|
@ -726,20 +734,20 @@ func TestNebula_authorizeToken(t *testing.T) {
|
|||
t1 := now()
|
||||
p, ca, signer := mustNebulaProvisioner(t)
|
||||
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
|
||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv)
|
||||
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{
|
||||
ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv)
|
||||
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{
|
||||
CertType: "host",
|
||||
KeyID: "test.lan",
|
||||
Principals: []string{"test.lan"},
|
||||
}, crt, priv)
|
||||
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv)
|
||||
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv)
|
||||
|
||||
// Token with errors
|
||||
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
|
||||
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
|
||||
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv)
|
||||
failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
|
||||
|
||||
// Not a nebula token
|
||||
jwk, err := generateJSONWebKey()
|
||||
|
@ -761,7 +769,7 @@ func TestNebula_authorizeToken(t *testing.T) {
|
|||
IssuedAt: jose.NewNumericDate(t1),
|
||||
NotBefore: jose.NewNumericDate(t1),
|
||||
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
|
||||
Audience: []string{p.audiences.Sign[0]},
|
||||
Audience: []string{p.ctl.Audiences.Sign[0]},
|
||||
}
|
||||
sshClaims := jose.Claims{
|
||||
ID: "[REPLACEME]",
|
||||
|
@ -770,7 +778,7 @@ func TestNebula_authorizeToken(t *testing.T) {
|
|||
IssuedAt: jose.NewNumericDate(t1),
|
||||
NotBefore: jose.NewNumericDate(t1),
|
||||
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
|
||||
Audience: []string{p.audiences.SSHSign[0]},
|
||||
Audience: []string{p.ctl.Audiences.SSHSign[0]},
|
||||
}
|
||||
|
||||
type args struct {
|
||||
|
@ -785,14 +793,14 @@ func TestNebula_authorizeToken(t *testing.T) {
|
|||
want1 *jwtPayload
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{
|
||||
{"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{
|
||||
Claims: x509Claims,
|
||||
SANs: []string{"10.1.0.1"},
|
||||
}, false},
|
||||
{"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{
|
||||
{"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{
|
||||
Claims: x509Claims,
|
||||
}, false},
|
||||
{"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{
|
||||
{"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
|
||||
Claims: sshClaims,
|
||||
Step: &stepPayload{
|
||||
SSH: &SignSSHOptions{
|
||||
|
@ -802,16 +810,16 @@ func TestNebula_authorizeToken(t *testing.T) {
|
|||
},
|
||||
},
|
||||
}, false},
|
||||
{"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{
|
||||
{"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
|
||||
Claims: sshClaims,
|
||||
}, false},
|
||||
{"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true},
|
||||
{"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true},
|
||||
{"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true},
|
||||
{"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||
{"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||
{"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||
{"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||
{"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||
{"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
|
|
@ -12,10 +12,13 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// openIDConfiguration contains the necessary properties in the
|
||||
|
@ -92,11 +95,11 @@ type OIDC struct {
|
|||
Options *Options `json:"options,omitempty"`
|
||||
configuration openIDConfiguration
|
||||
keyStore *keyStore
|
||||
claimer *Claimer
|
||||
getIdentityFunc GetIdentityFunc
|
||||
x509Policy x509PolicyEngine
|
||||
sshHostPolicy *hostPolicyEngine
|
||||
sshUserPolicy *userPolicyEngine
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
sshUserPolicy policy.UserPolicy
|
||||
}
|
||||
|
||||
func sanitizeEmail(email string) string {
|
||||
|
@ -175,11 +178,6 @@ func (o *OIDC) Init(config Config) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode and validate openid-configuration endpoint
|
||||
u, err := url.Parse(o.ConfigurationEndpoint)
|
||||
if err != nil {
|
||||
|
@ -212,21 +210,22 @@ func (o *OIDC) Init(config Config) (err error) {
|
|||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if o.x509Policy, err = newX509PolicyEngine(o.Options.GetX509Options()); err != nil {
|
||||
if o.x509Policy, err = policy.NewX509PolicyEngine(o.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for user certificates
|
||||
if o.sshUserPolicy, err = newSSHUserPolicyEngine(o.Options.GetSSHOptions()); err != nil {
|
||||
if o.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(o.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for host certificates
|
||||
if o.sshHostPolicy, err = newSSHHostPolicyEngine(o.Options.GetSSHOptions()); err != nil {
|
||||
if o.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(o.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
o.ctl, err = NewController(o, o.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// ValidatePayload validates the given token payload.
|
||||
|
@ -378,10 +377,10 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
|
||||
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()),
|
||||
profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()),
|
||||
// validators
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(o.x509Policy),
|
||||
}, nil
|
||||
}
|
||||
|
@ -391,15 +390,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// revocation status. Just confirms that the provisioner that created the
|
||||
// certificate was configured to allow renewals.
|
||||
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if o.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName())
|
||||
}
|
||||
return nil
|
||||
return o.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !o.claimer.IsSSHCAEnabled() {
|
||||
if !o.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName())
|
||||
}
|
||||
claims, err := o.authorizeToken(token)
|
||||
|
@ -414,7 +410,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
|||
// Get the identity using either the default identityFunc or one injected
|
||||
// externally. Note that the PreferredUsername might be empty.
|
||||
// TBD: Would preferred_username present a safety issue here?
|
||||
iden, err := o.getIdentityFunc(ctx, o, claims.Email)
|
||||
iden, err := o.ctl.GetIdentity(ctx, claims.Email)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
|
||||
}
|
||||
|
@ -465,11 +461,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
|
|||
|
||||
return append(signOptions,
|
||||
// Set the validity bounds if not set.
|
||||
&sshDefaultDuration{o.claimer},
|
||||
&sshDefaultDuration{o.ctl.Claimer},
|
||||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{o.claimer},
|
||||
&sshCertValidityValidator{o.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
|
|
|
@ -327,16 +327,16 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
|
|||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeOIDC))
|
||||
assert.Equals(t, v.Type, TypeOIDC)
|
||||
assert.Equals(t, v.Name, tt.prov.GetName())
|
||||
assert.Equals(t, v.CredentialID, tt.prov.ClientID)
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case profileDefaultDuration:
|
||||
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration())
|
||||
assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
|
||||
case defaultPublicKeyValidator:
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
|
||||
case emailOnlyIdentity:
|
||||
assert.Equals(t, string(v), "name@smallstep.com")
|
||||
case *x509NamePolicyValidator:
|
||||
|
@ -413,6 +413,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestOIDC_AuthorizeRenew(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
p1, err := generateOIDC()
|
||||
assert.FatalError(t, err)
|
||||
p2, err := generateOIDC()
|
||||
|
@ -421,7 +422,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
|
|||
// disable renewal
|
||||
disable := true
|
||||
p2.Claims = &Claims{DisableRenewal: &disable}
|
||||
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
type args struct {
|
||||
|
@ -434,8 +435,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
|
|||
code int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", p1, args{nil}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true},
|
||||
{"ok", p1, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusOK, false},
|
||||
{"fail/renew-disabled", p2, args{&x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}}, http.StatusUnauthorized, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -480,7 +487,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
|||
// disable sshCA
|
||||
disable := false
|
||||
p6.Claims = &Claims{EnableSSHCA: &disable}
|
||||
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
|
||||
p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
// Update configuration endpoints and initialize
|
||||
|
@ -496,10 +503,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
|||
assert.FatalError(t, p4.Init(config))
|
||||
assert.FatalError(t, p5.Init(config))
|
||||
|
||||
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
return &Identity{Usernames: []string{"max", "mariano"}}, nil
|
||||
}
|
||||
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
return nil, errors.New("force")
|
||||
}
|
||||
// Additional test needed for empty usernames and duplicate email and usernames
|
||||
|
@ -529,8 +536,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
|
|||
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
|
||||
assert.FatalError(t, err)
|
||||
|
||||
userDuration := p1.claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.claimer.DefaultHostSSHCertDuration()
|
||||
userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
|
||||
hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
|
||||
expectedUserOptions := &SignSSHOptions{
|
||||
CertType: "user", Principals: []string{"name", "name@smallstep.com"},
|
||||
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
|
||||
|
|
|
@ -5,8 +5,11 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
// CertificateOptions is an interface that returns a list of options passed when
|
||||
|
@ -58,10 +61,10 @@ type X509Options struct {
|
|||
TemplateData json.RawMessage `json:"templateData,omitempty"`
|
||||
|
||||
// AllowedNames contains the SANs the provisioner is authorized to sign
|
||||
AllowedNames *X509NameOptions `json:"allow,omitempty"`
|
||||
AllowedNames *policy.X509NameOptions
|
||||
|
||||
// DeniedNames contains the SANs the provisioner is not authorized to sign
|
||||
DeniedNames *X509NameOptions `json:"deny,omitempty"`
|
||||
DeniedNames *policy.X509NameOptions
|
||||
}
|
||||
|
||||
// HasTemplate returns true if a template is defined in the provisioner options.
|
||||
|
@ -69,41 +72,24 @@ func (o *X509Options) HasTemplate() bool {
|
|||
return o != nil && (o.Template != "" || o.TemplateFile != "")
|
||||
}
|
||||
|
||||
// GetAllowedNameOptions returns the AllowedNameOptions, which models the
|
||||
// GetAllowedNameOptions returns the AllowedNames, which models the
|
||||
// SANs that a provisioner is authorized to sign x509 certificates for.
|
||||
func (o *X509Options) GetAllowedNameOptions() *X509NameOptions {
|
||||
func (o *X509Options) GetAllowedNameOptions() *policy.X509NameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedNameOptions returns the DeniedNameOptions, which models the
|
||||
// GetDeniedNameOptions returns the DeniedNames, which models the
|
||||
// SANs that a provisioner is NOT authorized to sign x509 certificates for.
|
||||
func (o *X509Options) GetDeniedNameOptions() *X509NameOptions {
|
||||
func (o *X509Options) GetDeniedNameOptions() *policy.X509NameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.DeniedNames
|
||||
}
|
||||
|
||||
// X509NameOptions models the X509 name policy configuration.
|
||||
type X509NameOptions struct {
|
||||
DNSDomains []string `json:"dns,omitempty"`
|
||||
IPRanges []string `json:"ip,omitempty"`
|
||||
EmailAddresses []string `json:"email,omitempty"`
|
||||
URIDomains []string `json:"uri,omitempty"`
|
||||
}
|
||||
|
||||
// HasNames checks if the AllowedNameOptions has one or more
|
||||
// names configured.
|
||||
func (o *X509NameOptions) HasNames() bool {
|
||||
return len(o.DNSDomains) > 0 ||
|
||||
len(o.IPRanges) > 0 ||
|
||||
len(o.EmailAddresses) > 0 ||
|
||||
len(o.URIDomains) > 0
|
||||
}
|
||||
|
||||
// TemplateOptions generates a CertificateOptions with the template and data
|
||||
// defined in the ProvisionerOptions, the provisioner generated data, and the
|
||||
// user data provided in the request. If no template has been provided,
|
||||
|
|
|
@ -1,156 +0,0 @@
|
|||
package provisioner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/smallstep/certificates/policy"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type sshPolicyEngineType string
|
||||
|
||||
const (
|
||||
userPolicyEngineType sshPolicyEngineType = "user"
|
||||
hostPolicyEngineType sshPolicyEngineType = "host"
|
||||
)
|
||||
|
||||
var certTypeToPolicyEngineType = map[uint32]sshPolicyEngineType{
|
||||
uint32(ssh.UserCert): userPolicyEngineType,
|
||||
uint32(ssh.HostCert): hostPolicyEngineType,
|
||||
}
|
||||
|
||||
type x509PolicyEngine interface {
|
||||
policy.X509NamePolicyEngine
|
||||
}
|
||||
|
||||
type userPolicyEngine struct {
|
||||
policy.SSHNamePolicyEngine
|
||||
}
|
||||
|
||||
type hostPolicyEngine struct {
|
||||
policy.SSHNamePolicyEngine
|
||||
}
|
||||
|
||||
// newX509PolicyEngine creates a new x509 name policy engine
|
||||
func newX509PolicyEngine(x509Opts *X509Options) (x509PolicyEngine, error) {
|
||||
|
||||
if x509Opts == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
options := []policy.NamePolicyOption{
|
||||
policy.WithSubjectCommonNameVerification(), // enable x509 Subject Common Name validation by default
|
||||
}
|
||||
|
||||
allowed := x509Opts.GetAllowedNameOptions()
|
||||
if allowed != nil && allowed.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithPermittedDNSDomains(allowed.DNSDomains),
|
||||
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges),
|
||||
policy.WithPermittedEmailAddresses(allowed.EmailAddresses),
|
||||
policy.WithPermittedURIDomains(allowed.URIDomains),
|
||||
)
|
||||
}
|
||||
|
||||
denied := x509Opts.GetDeniedNameOptions()
|
||||
if denied != nil && denied.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithExcludedDNSDomains(denied.DNSDomains),
|
||||
policy.WithExcludedIPsOrCIDRs(denied.IPRanges),
|
||||
policy.WithExcludedEmailAddresses(denied.EmailAddresses),
|
||||
policy.WithExcludedURIDomains(denied.URIDomains),
|
||||
)
|
||||
}
|
||||
|
||||
return policy.New(options...)
|
||||
}
|
||||
|
||||
// newSSHUserPolicyEngine creates a new SSH user certificate policy engine
|
||||
func newSSHUserPolicyEngine(sshOpts *SSHOptions) (*userPolicyEngine, error) {
|
||||
policyEngine, err := newSSHPolicyEngine(sshOpts, userPolicyEngineType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ensure we're not wrapping a nil engine
|
||||
if policyEngine == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &userPolicyEngine{
|
||||
SSHNamePolicyEngine: policyEngine,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newSSHHostPolicyEngine create a new SSH host certificate policy engine
|
||||
func newSSHHostPolicyEngine(sshOpts *SSHOptions) (*hostPolicyEngine, error) {
|
||||
policyEngine, err := newSSHPolicyEngine(sshOpts, hostPolicyEngineType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ensure we're not wrapping a nil engine
|
||||
if policyEngine == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return &hostPolicyEngine{
|
||||
SSHNamePolicyEngine: policyEngine,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// newSSHPolicyEngine creates a new SSH name policy engine
|
||||
func newSSHPolicyEngine(sshOpts *SSHOptions, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
|
||||
|
||||
if sshOpts == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
allowed *SSHNameOptions
|
||||
denied *SSHNameOptions
|
||||
)
|
||||
|
||||
// TODO: embed the type in the policy engine itself for reference?
|
||||
switch typ {
|
||||
case userPolicyEngineType:
|
||||
if sshOpts.User != nil {
|
||||
allowed = sshOpts.User.GetAllowedNameOptions()
|
||||
denied = sshOpts.User.GetDeniedNameOptions()
|
||||
}
|
||||
case hostPolicyEngineType:
|
||||
if sshOpts.Host != nil {
|
||||
allowed = sshOpts.Host.AllowedNames
|
||||
denied = sshOpts.Host.DeniedNames
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
|
||||
}
|
||||
|
||||
options := []policy.NamePolicyOption{}
|
||||
|
||||
if allowed != nil && allowed.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithPermittedDNSDomains(allowed.DNSDomains),
|
||||
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges),
|
||||
policy.WithPermittedEmailAddresses(allowed.EmailAddresses),
|
||||
policy.WithPermittedPrincipals(allowed.Principals),
|
||||
)
|
||||
}
|
||||
|
||||
if denied != nil && denied.HasNames() {
|
||||
options = append(options,
|
||||
policy.WithExcludedDNSDomains(denied.DNSDomains),
|
||||
policy.WithExcludedIPsOrCIDRs(denied.IPRanges),
|
||||
policy.WithExcludedEmailAddresses(denied.EmailAddresses),
|
||||
policy.WithExcludedPrincipals(denied.Principals),
|
||||
)
|
||||
}
|
||||
|
||||
// Return nil, because there's no policy to execute. This is
|
||||
// important, because the logic that determines user vs. host certs
|
||||
// are allowed depends on this fact. The two policy engines are
|
||||
// not aware of eachother, so this check is performed in the
|
||||
// SSH name validator, instead.
|
||||
if len(options) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return policy.New(options...)
|
||||
}
|
|
@ -6,7 +6,6 @@ import (
|
|||
"encoding/json"
|
||||
stderrors "errors"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -47,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse")
|
|||
// Audiences stores all supported audiences by request type.
|
||||
type Audiences struct {
|
||||
Sign []string
|
||||
Renew []string
|
||||
Revoke []string
|
||||
SSHSign []string
|
||||
SSHRevoke []string
|
||||
|
@ -57,6 +57,7 @@ type Audiences struct {
|
|||
// All returns all supported audiences across all request types in one list.
|
||||
func (a Audiences) All() (auds []string) {
|
||||
auds = a.Sign
|
||||
auds = append(auds, a.Renew...)
|
||||
auds = append(auds, a.Revoke...)
|
||||
auds = append(auds, a.SSHSign...)
|
||||
auds = append(auds, a.SSHRevoke...)
|
||||
|
@ -70,6 +71,7 @@ func (a Audiences) All() (auds []string) {
|
|||
func (a Audiences) WithFragment(fragment string) Audiences {
|
||||
ret := Audiences{
|
||||
Sign: make([]string, len(a.Sign)),
|
||||
Renew: make([]string, len(a.Renew)),
|
||||
Revoke: make([]string, len(a.Revoke)),
|
||||
SSHSign: make([]string, len(a.SSHSign)),
|
||||
SSHRevoke: make([]string, len(a.SSHRevoke)),
|
||||
|
@ -83,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences {
|
|||
ret.Sign[i] = s
|
||||
}
|
||||
}
|
||||
for i, s := range a.Renew {
|
||||
if u, err := url.Parse(s); err == nil {
|
||||
ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
|
||||
} else {
|
||||
ret.Renew[i] = s
|
||||
}
|
||||
}
|
||||
for i, s := range a.Revoke {
|
||||
if u, err := url.Parse(s); err == nil {
|
||||
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
|
||||
|
@ -210,6 +219,12 @@ type Config struct {
|
|||
// GetIdentityFunc is a function that returns an identity that will be
|
||||
// used by the provisioner to populate certificate attributes.
|
||||
GetIdentityFunc GetIdentityFunc
|
||||
// AuthorizeRenewFunc is a function that returns nil if a given X.509
|
||||
// certificate can be renewed.
|
||||
AuthorizeRenewFunc AuthorizeRenewFunc
|
||||
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
|
||||
// certificate can be renewed.
|
||||
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
|
||||
}
|
||||
|
||||
type provisioner struct {
|
||||
|
@ -278,32 +293,6 @@ func (l *List) UnmarshalJSON(data []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
|
||||
|
||||
// SanitizeSSHUserPrincipal grabs an email or a string with the format
|
||||
// local@domain and returns a sanitized version of the local, valid to be used
|
||||
// as a user name. If the email starts with a letter between a and z, the
|
||||
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
|
||||
func SanitizeSSHUserPrincipal(email string) string {
|
||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||
email = email[:i]
|
||||
}
|
||||
return strings.Map(func(r rune) rune {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
return r
|
||||
case r >= '0' && r <= '9':
|
||||
return r
|
||||
case r == '-':
|
||||
return '-'
|
||||
case r == '.': // drop dots
|
||||
return -1
|
||||
default:
|
||||
return '_'
|
||||
}
|
||||
}, strings.ToLower(email))
|
||||
}
|
||||
|
||||
type base struct{}
|
||||
|
||||
// AuthorizeSign returns an unimplemented error. Provisioners should overwrite
|
||||
|
@ -348,66 +337,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif
|
|||
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
|
||||
}
|
||||
|
||||
// Identity is the type representing an externally supplied identity that is used
|
||||
// by provisioners to populate certificate fields.
|
||||
type Identity struct {
|
||||
Usernames []string `json:"usernames"`
|
||||
Permissions `json:"permissions"`
|
||||
}
|
||||
|
||||
// Permissions defines extra extensions and critical options to grant to an SSH certificate.
|
||||
type Permissions struct {
|
||||
Extensions map[string]string `json:"extensions"`
|
||||
CriticalOptions map[string]string `json:"criticalOptions"`
|
||||
}
|
||||
|
||||
// GetIdentityFunc is a function that returns an identity.
|
||||
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
|
||||
|
||||
// DefaultIdentityFunc return a default identity depending on the provisioner
|
||||
// type. For OIDC email is always present and the usernames might
|
||||
// contain empty strings.
|
||||
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
|
||||
switch k := p.(type) {
|
||||
case *OIDC:
|
||||
// OIDC principals would be:
|
||||
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
|
||||
// 2. Sanitized local.
|
||||
// 3. Raw local (if different).
|
||||
// 4. Email address.
|
||||
name := SanitizeSSHUserPrincipal(email)
|
||||
if !sshUserRegex.MatchString(name) {
|
||||
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
|
||||
}
|
||||
usernames := []string{name}
|
||||
if i := strings.LastIndex(email, "@"); i >= 0 {
|
||||
usernames = append(usernames, email[:i])
|
||||
}
|
||||
usernames = append(usernames, email)
|
||||
return &Identity{
|
||||
Usernames: SanitizeStringSlices(usernames),
|
||||
}, nil
|
||||
default:
|
||||
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeStringSlices removes duplicated an empty strings.
|
||||
func SanitizeStringSlices(original []string) []string {
|
||||
output := []string{}
|
||||
seen := make(map[string]struct{})
|
||||
for _, entry := range original {
|
||||
if entry == "" {
|
||||
continue
|
||||
}
|
||||
if _, value := seen[entry]; !value {
|
||||
seen[entry] = struct{}{}
|
||||
output = append(output, entry)
|
||||
}
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
// MockProvisioner for testing
|
||||
type MockProvisioner struct {
|
||||
Mret1, Mret2, Mret3 interface{}
|
||||
|
|
|
@ -5,7 +5,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/policy"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
// SCEP is the SCEP provisioner type, an entity that can authorize the
|
||||
|
@ -15,23 +16,25 @@ type SCEP struct {
|
|||
ID string `json:"-"`
|
||||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
|
||||
ForceCN bool `json:"forceCN,omitempty"`
|
||||
ChallengePassword string `json:"challenge,omitempty"`
|
||||
Capabilities []string `json:"capabilities,omitempty"`
|
||||
|
||||
// IncludeRoot makes the provisioner return the CA root in addition to the
|
||||
// intermediate in the GetCACerts response
|
||||
IncludeRoot bool `json:"includeRoot,omitempty"`
|
||||
|
||||
// MinimumPublicKeyLength is the minimum length for public keys in CSRs
|
||||
MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"`
|
||||
|
||||
// Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7
|
||||
// at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63
|
||||
// Defaults to 0, being DES-CBC
|
||||
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
claimer *Claimer
|
||||
x509Policy policy.X509NamePolicyEngine
|
||||
ctl *Controller
|
||||
x509Policy policy.X509Policy
|
||||
secretChallengePassword string
|
||||
encryptionAlgorithm int
|
||||
}
|
||||
|
@ -78,7 +81,7 @@ func (s *SCEP) GetOptions() *Options {
|
|||
// DefaultTLSCertDuration returns the default TLS cert duration enforced by
|
||||
// the provisioner.
|
||||
func (s *SCEP) DefaultTLSCertDuration() time.Duration {
|
||||
return s.claimer.DefaultTLSCertDuration()
|
||||
return s.ctl.Claimer.DefaultTLSCertDuration()
|
||||
}
|
||||
|
||||
// Init initializes and validates the fields of a SCEP type.
|
||||
|
@ -90,11 +93,6 @@ func (s *SCEP) Init(config Config) (err error) {
|
|||
return errors.New("provisioner name cannot be empty")
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mask the actual challenge value, so it won't be marshaled
|
||||
s.secretChallengePassword = s.ChallengePassword
|
||||
s.ChallengePassword = "*** redacted ***"
|
||||
|
@ -116,11 +114,12 @@ func (s *SCEP) Init(config Config) (err error) {
|
|||
// TODO: add other, SCEP specific, options?
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if s.x509Policy, err = newX509PolicyEngine(s.Options.GetX509Options()); err != nil {
|
||||
if s.x509Policy, err = policy.NewX509PolicyEngine(s.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
s.ctl, err = NewController(s, s.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// AuthorizeSign does not do any verification, because all verification is handled
|
||||
|
@ -131,10 +130,10 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
|
|||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeSCEP, s.Name, ""),
|
||||
newForceCNOption(s.ForceCN),
|
||||
profileDefaultDuration(s.claimer.DefaultTLSCertDuration()),
|
||||
profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()),
|
||||
// validators
|
||||
newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength),
|
||||
newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(s.x509Policy),
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -14,10 +13,11 @@ import (
|
|||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
)
|
||||
|
||||
// DefaultCertValidity is the default validity for a certificate if none is specified.
|
||||
|
@ -407,18 +407,17 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
|
|||
// x509NamePolicyValidator validates that the certificate (to be signed)
|
||||
// contains only allowed SANs.
|
||||
type x509NamePolicyValidator struct {
|
||||
policyEngine x509PolicyEngine
|
||||
policyEngine policy.X509Policy
|
||||
}
|
||||
|
||||
// newX509NamePolicyValidator return a new SANs allow/deny validator.
|
||||
func newX509NamePolicyValidator(engine x509PolicyEngine) *x509NamePolicyValidator {
|
||||
func newX509NamePolicyValidator(engine policy.X509Policy) *x509NamePolicyValidator {
|
||||
return &x509NamePolicyValidator{
|
||||
policyEngine: engine,
|
||||
}
|
||||
}
|
||||
|
||||
// Valid validates validates that the certificate (to be signed)
|
||||
// contains only allowed SANs.
|
||||
// Valid validates that the certificate (to be signed) contains only allowed SANs.
|
||||
func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error {
|
||||
if v.policyEngine == nil {
|
||||
return nil
|
||||
|
@ -427,17 +426,17 @@ func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) e
|
|||
return err
|
||||
}
|
||||
|
||||
var (
|
||||
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
||||
)
|
||||
// var (
|
||||
// stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
|
||||
// stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
|
||||
// )
|
||||
|
||||
type stepProvisionerASN1 struct {
|
||||
Type int
|
||||
Name []byte
|
||||
CredentialID []byte
|
||||
KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||
}
|
||||
// type stepProvisionerASN1 struct {
|
||||
// Type int
|
||||
// Name []byte
|
||||
// CredentialID []byte
|
||||
// KeyValuePairs []string `asn1:"optional,omitempty"`
|
||||
// }
|
||||
|
||||
type forceCNOption struct {
|
||||
ForceCN bool
|
||||
|
@ -464,23 +463,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
|||
}
|
||||
|
||||
type provisionerExtensionOption struct {
|
||||
Type int
|
||||
Name string
|
||||
CredentialID string
|
||||
KeyValuePairs []string
|
||||
Extension
|
||||
}
|
||||
|
||||
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
|
||||
return &provisionerExtensionOption{
|
||||
Type: int(typ),
|
||||
Extension: Extension{
|
||||
Type: typ,
|
||||
Name: name,
|
||||
CredentialID: credentialID,
|
||||
KeyValuePairs: keyValuePairs,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
|
||||
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...)
|
||||
ext, err := o.ToExtension()
|
||||
if err != nil {
|
||||
return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
|
||||
}
|
||||
|
@ -494,20 +492,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption
|
|||
cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) {
|
||||
b, err := asn1.Marshal(stepProvisionerASN1{
|
||||
Type: typ,
|
||||
Name: []byte(name),
|
||||
CredentialID: []byte(credentialID),
|
||||
KeyValuePairs: keyValuePairs,
|
||||
})
|
||||
if err != nil {
|
||||
return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension")
|
||||
}
|
||||
return pkix.Extension{
|
||||
Id: stepOIDProvisioner,
|
||||
Critical: false,
|
||||
Value: b,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) {
|
|||
valid: func(cert *x509.Certificate) {
|
||||
if assert.Len(t, 1, cert.ExtraExtensions) {
|
||||
ext := cert.ExtraExtensions[0]
|
||||
assert.Equals(t, ext.Id, stepOIDProvisioner)
|
||||
assert.Equals(t, ext.Id, StepOIDProvisioner)
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"ok/prepend": func() test {
|
||||
return test{
|
||||
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
|
||||
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
|
||||
valid: func(cert *x509.Certificate) {
|
||||
if assert.Len(t, 3, cert.ExtraExtensions) {
|
||||
ext := cert.ExtraExtensions[0]
|
||||
assert.Equals(t, ext.Id, stepOIDProvisioner)
|
||||
assert.Equals(t, ext.Id, StepOIDProvisioner)
|
||||
assert.False(t, ext.Critical)
|
||||
}
|
||||
},
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/keyutil"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
@ -448,20 +449,19 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOpti
|
|||
// sshNamePolicyValidator validates that the certificate (to be signed)
|
||||
// contains only allowed principals.
|
||||
type sshNamePolicyValidator struct {
|
||||
hostPolicyEngine *hostPolicyEngine
|
||||
userPolicyEngine *userPolicyEngine
|
||||
hostPolicyEngine policy.HostPolicy
|
||||
userPolicyEngine policy.UserPolicy
|
||||
}
|
||||
|
||||
// newSSHNamePolicyValidator return a new SSH allow/deny validator.
|
||||
func newSSHNamePolicyValidator(host *hostPolicyEngine, user *userPolicyEngine) *sshNamePolicyValidator {
|
||||
func newSSHNamePolicyValidator(host policy.HostPolicy, user policy.UserPolicy) *sshNamePolicyValidator {
|
||||
return &sshNamePolicyValidator{
|
||||
hostPolicyEngine: host,
|
||||
userPolicyEngine: user,
|
||||
}
|
||||
}
|
||||
|
||||
// Valid validates validates that the certificate (to be signed)
|
||||
// contains only allowed principals.
|
||||
// Valid validates that the certificate (to be signed) contains only allowed principals.
|
||||
func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error {
|
||||
if v.hostPolicyEngine == nil && v.userPolicyEngine == nil {
|
||||
// no policy configured at all; allow anything
|
||||
|
@ -473,29 +473,25 @@ func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions)
|
|||
// the same for host certs: if only a user policy engine is configured, host
|
||||
// certs are denied. When both policy engines are configured, the type of
|
||||
// cert determines which policy engine is used.
|
||||
policyType, ok := certTypeToPolicyEngineType[cert.CertType]
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected SSH cert type %d", cert.CertType)
|
||||
}
|
||||
switch policyType {
|
||||
case hostPolicyEngineType:
|
||||
switch cert.CertType {
|
||||
case ssh.HostCert:
|
||||
// when no host policy engine is configured, but a user policy engine is
|
||||
// configured, we don't allow the host certificate.
|
||||
// configured, the host certificate is denied.
|
||||
if v.hostPolicyEngine == nil && v.userPolicyEngine != nil {
|
||||
return errors.New("SSH host certificate not authorized") // TODO: include principals in message?
|
||||
return errors.New("SSH host certificate not authorized")
|
||||
}
|
||||
_, err := v.hostPolicyEngine.ArePrincipalsAllowed(cert)
|
||||
return err
|
||||
case userPolicyEngineType:
|
||||
case ssh.UserCert:
|
||||
// when no user policy engine is configured, but a host policy engine is
|
||||
// configured, we don't allow the user certificate.
|
||||
// configured, the user certificate is denied.
|
||||
if v.userPolicyEngine == nil && v.hostPolicyEngine != nil {
|
||||
return errors.New("SSH user certificate not authorized") // TODO: include principals in message?
|
||||
return errors.New("SSH user certificate not authorized")
|
||||
}
|
||||
_, err := v.userPolicyEngine.ArePrincipalsAllowed(cert)
|
||||
return err
|
||||
default:
|
||||
return fmt.Errorf("unexpected policy engine type %q", policyType) // satisfy return; shouldn't happen
|
||||
return fmt.Errorf("unexpected SSH certificate type %d", cert.CertType) // satisfy return; shouldn't happen
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
|
|||
func Test_sshCertValidityValidator(t *testing.T) {
|
||||
p, err := generateX5C(nil)
|
||||
assert.FatalError(t, err)
|
||||
v := sshCertValidityValidator{p.claimer}
|
||||
v := sshCertValidityValidator{p.ctl.Claimer}
|
||||
n := now()
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
tests := map[string]func() test{
|
||||
"fail/type-not-set": func() test {
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
|
||||
cert: &ssh.Certificate{
|
||||
ValidAfter: uint64(n.Unix()),
|
||||
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
|
||||
|
@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
},
|
||||
"fail/type-not-recognized": func() test {
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
|
||||
cert: &ssh.Certificate{
|
||||
CertType: 4,
|
||||
ValidAfter: uint64(n.Unix()),
|
||||
|
@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
},
|
||||
"fail/requested-validAfter-after-limit": func() test {
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
|
||||
cert: &ssh.Certificate{
|
||||
CertType: 1,
|
||||
ValidAfter: uint64(n.Add(2 * time.Hour).Unix()),
|
||||
|
@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
},
|
||||
"fail/requested-validBefore-after-limit": func() test {
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
|
||||
cert: &ssh.Certificate{
|
||||
CertType: 1,
|
||||
ValidAfter: uint64(n.Unix()),
|
||||
|
@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
"ok/no-limit": func() test {
|
||||
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
|
||||
cert: &ssh.Certificate{
|
||||
CertType: 1,
|
||||
},
|
||||
|
@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
"ok/defaults": func() test {
|
||||
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
|
||||
cert: &ssh.Certificate{
|
||||
CertType: 1,
|
||||
},
|
||||
|
@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
"ok/valid-requested-validBefore": func() test {
|
||||
va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
|
||||
cert: &ssh.Certificate{
|
||||
CertType: 1,
|
||||
ValidAfter: va,
|
||||
|
@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
"ok/empty-requested-validBefore-limit-after-default": func() test {
|
||||
va := uint64(n.Unix())
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)},
|
||||
cert: &ssh.Certificate{
|
||||
CertType: 1,
|
||||
ValidAfter: va,
|
||||
|
@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) {
|
|||
"ok/empty-requested-validBefore-limit-before-default": func() test {
|
||||
va := uint64(n.Unix())
|
||||
return test{
|
||||
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
|
||||
svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
|
||||
cert: &ssh.Certificate{
|
||||
CertType: 1,
|
||||
ValidAfter: va,
|
||||
|
|
|
@ -6,6 +6,8 @@ import (
|
|||
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
)
|
||||
|
||||
// SSHCertificateOptions is an interface that returns a list of options passed when
|
||||
|
@ -35,32 +37,58 @@ type SSHOptions struct {
|
|||
TemplateData json.RawMessage `json:"templateData,omitempty"`
|
||||
|
||||
// User contains SSH user certificate options.
|
||||
User *SSHUserCertificateOptions `json:"user,omitempty"`
|
||||
User *policy.SSHUserCertificateOptions
|
||||
|
||||
// Host contains SSH host certificate options.
|
||||
Host *SSHHostCertificateOptions `json:"host,omitempty"`
|
||||
Host *policy.SSHHostCertificateOptions
|
||||
}
|
||||
|
||||
// SSHUserCertificateOptions is a collection of SSH user certificate options.
|
||||
type SSHUserCertificateOptions struct {
|
||||
// AllowedNames contains the names the provisioner is authorized to sign
|
||||
AllowedNames *SSHNameOptions `json:"allow,omitempty"`
|
||||
|
||||
// DeniedNames contains the names the provisioner is not authorized to sign
|
||||
DeniedNames *SSHNameOptions `json:"deny,omitempty"`
|
||||
// GetAllowedUserNameOptions returns the SSHNameOptions that are
|
||||
// allowed when SSH User certificates are requested.
|
||||
func (o *SSHOptions) GetAllowedUserNameOptions() *policy.SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
if o.User == nil {
|
||||
return nil
|
||||
}
|
||||
return o.User.AllowedNames
|
||||
}
|
||||
|
||||
// SSHHostCertificateOptions is a collection of SSH host certificate options.
|
||||
// It's an alias of SSHUserCertificateOptions, as the options are the same
|
||||
// for both types of certificates.
|
||||
type SSHHostCertificateOptions SSHUserCertificateOptions
|
||||
// GetDeniedUserNameOptions returns the SSHNameOptions that are
|
||||
// denied when SSH user certificates are requested.
|
||||
func (o *SSHOptions) GetDeniedUserNameOptions() *policy.SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
if o.User == nil {
|
||||
return nil
|
||||
}
|
||||
return o.User.DeniedNames
|
||||
}
|
||||
|
||||
// SSHNameOptions models the SSH name policy configuration.
|
||||
type SSHNameOptions struct {
|
||||
DNSDomains []string `json:"dns,omitempty"`
|
||||
IPRanges []string `json:"ip,omitempty"`
|
||||
EmailAddresses []string `json:"email,omitempty"`
|
||||
Principals []string `json:"principal,omitempty"`
|
||||
// GetAllowedHostNameOptions returns the SSHNameOptions that are
|
||||
// allowed when SSH host certificates are requested.
|
||||
func (o *SSHOptions) GetAllowedHostNameOptions() *policy.SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
if o.Host == nil {
|
||||
return nil
|
||||
}
|
||||
return o.Host.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedHostNameOptions returns the SSHNameOptions that are
|
||||
// denied when SSH host certificates are requested.
|
||||
func (o *SSHOptions) GetDeniedHostNameOptions() *policy.SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
if o.Host == nil {
|
||||
return nil
|
||||
}
|
||||
return o.Host.DeniedNames
|
||||
}
|
||||
|
||||
// HasTemplate returns true if a template is defined in the provisioner options.
|
||||
|
@ -68,32 +96,6 @@ func (o *SSHOptions) HasTemplate() bool {
|
|||
return o != nil && (o.Template != "" || o.TemplateFile != "")
|
||||
}
|
||||
|
||||
// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the
|
||||
// names that a provisioner is authorized to sign SSH certificates for.
|
||||
func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.AllowedNames
|
||||
}
|
||||
|
||||
// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the
|
||||
// names that a provisioner is NOT authorized to sign SSH certificates for.
|
||||
func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions {
|
||||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
return o.DeniedNames
|
||||
}
|
||||
|
||||
// HasNames checks if the SSHNameOptions has one or more
|
||||
// names configured.
|
||||
func (o *SSHNameOptions) HasNames() bool {
|
||||
return len(o.DNSDomains) > 0 ||
|
||||
len(o.EmailAddresses) > 0 ||
|
||||
len(o.Principals) > 0
|
||||
}
|
||||
|
||||
// TemplateSSHOptions generates a SSHCertificateOptions with the template and
|
||||
// data defined in the ProvisionerOptions, the provisioner generated data, and
|
||||
// the user data provided in the request. If no template has been provided,
|
||||
|
|
|
@ -29,8 +29,7 @@ type SSHPOP struct {
|
|||
Type string `json:"type"`
|
||||
Name string `json:"name"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
claimer *Claimer
|
||||
audiences Audiences
|
||||
ctl *Controller
|
||||
sshPubKeys *SSHKeys
|
||||
}
|
||||
|
||||
|
@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) {
|
|||
}
|
||||
|
||||
// Init initializes and validates the fields of a SSHPOP type.
|
||||
func (p *SSHPOP) Init(config Config) error {
|
||||
func (p *SSHPOP) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
|
@ -93,17 +92,15 @@ func (p *SSHPOP) Init(config Config) error {
|
|||
return errors.New("provisioner public SSH validation keys cannot be empty")
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
var err error
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(hs): initialize the policy engine and add it as an SSH cert validator
|
||||
|
||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.sshPubKeys = config.SSHKeys
|
||||
return nil
|
||||
|
||||
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
|
||||
// Update claims with global ones
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// authorizeToken performs common jwt authorization actions and returns the
|
||||
|
@ -111,7 +108,7 @@ func (p *SSHPOP) Init(config Config) error {
|
|||
// e.g. a Sign request will auth/validate different fields than a Revoke request.
|
||||
//
|
||||
// Checking for certificate revocation has been moved to the authority package.
|
||||
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) {
|
||||
func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) {
|
||||
sshCert, jwt, err := ExtractSSHPOPCert(token)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusUnauthorized, err,
|
||||
|
@ -119,13 +116,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
}
|
||||
|
||||
// Check validity period of the certificate.
|
||||
n := time.Now()
|
||||
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) {
|
||||
//
|
||||
// Controller.AuthorizeSSHRenew will validate this on the renewal flow.
|
||||
if checkValidity {
|
||||
unixNow := time.Now().Unix()
|
||||
if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) {
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
|
||||
}
|
||||
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) {
|
||||
if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
|
||||
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
|
||||
}
|
||||
}
|
||||
|
||||
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
|
||||
if !ok {
|
||||
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
|
||||
|
@ -188,7 +190,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
|
|||
// AuthorizeSSHRevoke validates the authorization token and extracts/validates
|
||||
// the SSH certificate from the ssh-pop header.
|
||||
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true)
|
||||
if err != nil {
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
|
||||
}
|
||||
|
@ -201,22 +203,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
|
|||
// AuthorizeSSHRenew validates the authorization token and extracts/validates
|
||||
// the SSH certificate from the ssh-pop header.
|
||||
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHRenew)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
|
||||
}
|
||||
if claims.sshCert.CertType != ssh.HostCert {
|
||||
return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
|
||||
}
|
||||
|
||||
return claims.sshCert, nil
|
||||
|
||||
return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert)
|
||||
}
|
||||
|
||||
// AuthorizeSSHRekey validates the authorization token and extracts/validates
|
||||
// the SSH certificate from the ssh-pop header.
|
||||
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHRekey)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true)
|
||||
if err != nil {
|
||||
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
|
||||
}
|
||||
|
@ -227,11 +227,10 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
|
|||
// Validate public key
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require and validate all the default fields in the SSH certificate.
|
||||
&sshCertDefaultValidator{},
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate
|
||||
|
|
|
@ -38,6 +38,7 @@ func TestSSHPOP_Getters(t *testing.T) {
|
|||
}
|
||||
|
||||
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
|
||||
now := time.Now()
|
||||
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
@ -46,6 +47,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if cert.ValidAfter == 0 {
|
||||
cert.ValidAfter = uint64(now.Unix())
|
||||
}
|
||||
if cert.ValidBefore == 0 {
|
||||
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
|
||||
}
|
||||
if err := cert.SignCert(rand.Reader, signer); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -207,7 +214,7 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
|
||||
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tc.code)
|
||||
|
@ -455,7 +462,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
|
|||
case *sshDefaultPublicKeyValidator:
|
||||
case *sshCertDefaultValidator:
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||
default:
|
||||
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
|
||||
}
|
||||
|
|
21
authority/provisioner/testdata/certs/bad-extension.crt
vendored
Normal file
21
authority/provisioner/testdata/certs/bad-extension.crt
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi
|
||||
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy
|
||||
MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
|
||||
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
|
||||
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
|
||||
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
|
||||
KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc
|
||||
pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE
|
||||
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
|
||||
ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg
|
||||
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
|
||||
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
|
||||
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
|
||||
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
|
||||
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
|
||||
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
|
||||
LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49
|
||||
BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY
|
||||
wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g=
|
||||
-----END CERTIFICATE-----
|
22
authority/provisioner/testdata/certs/good-extension.crt
vendored
Normal file
22
authority/provisioner/testdata/certs/good-extension.crt
vendored
Normal file
|
@ -0,0 +1,22 @@
|
|||
-----BEGIN CERTIFICATE-----
|
||||
MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi
|
||||
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx
|
||||
MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
|
||||
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
|
||||
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
|
||||
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
|
||||
KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0
|
||||
4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE
|
||||
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
|
||||
ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg
|
||||
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
|
||||
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
|
||||
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
|
||||
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
|
||||
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
|
||||
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
|
||||
LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz
|
||||
bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf
|
||||
SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB
|
||||
bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w==
|
||||
-----END CERTIFICATE-----
|
|
@ -25,12 +25,12 @@ import (
|
|||
|
||||
var (
|
||||
defaultDisableRenewal = false
|
||||
defaultAllowRenewAfterExpiry = false
|
||||
defaultEnableSSHCA = true
|
||||
globalProvisionerClaims = Claims{
|
||||
MinTLSDur: &Duration{5 * time.Minute},
|
||||
MaxTLSDur: &Duration{24 * time.Hour},
|
||||
DefaultTLSDur: &Duration{24 * time.Hour},
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
|
||||
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
|
||||
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
|
||||
|
@ -38,6 +38,8 @@ var (
|
|||
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
|
||||
EnableSSHCA: &defaultEnableSSHCA,
|
||||
DisableRenewal: &defaultDisableRenewal,
|
||||
AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry,
|
||||
}
|
||||
testAudiences = Audiences{
|
||||
Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},
|
||||
|
@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &JWK{
|
||||
|
||||
p := &JWK{
|
||||
Name: name,
|
||||
Type: "JWK",
|
||||
Key: &public,
|
||||
EncryptedKey: encrypted,
|
||||
Claims: &globalProvisionerClaims,
|
||||
audiences: testAudiences,
|
||||
claimer: claimer,
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences,
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
|
||||
|
@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pubKeys := []interface{}{fooPub, barPub}
|
||||
if inputPubKey != nil {
|
||||
pubKeys = append(pubKeys, inputPubKey)
|
||||
}
|
||||
|
||||
return &K8sSA{
|
||||
p := &K8sSA{
|
||||
Name: K8sSAName,
|
||||
Type: "K8sSA",
|
||||
Claims: &globalProvisionerClaims,
|
||||
audiences: testAudiences,
|
||||
claimer: claimer,
|
||||
pubKeys: pubKeys,
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences,
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateSSHPOP() (*SSHPOP, error) {
|
||||
|
@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &SSHPOP{
|
||||
p := &SSHPOP{
|
||||
Name: name,
|
||||
Type: "SSHPOP",
|
||||
Claims: &globalProvisionerClaims,
|
||||
audiences: testAudiences,
|
||||
claimer: claimer,
|
||||
sshPubKeys: &SSHKeys{
|
||||
UserKeys: []ssh.PublicKey{userKey},
|
||||
HostKeys: []ssh.PublicKey{hostKey},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences,
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateX5C(root []byte) (*X5C, error) {
|
||||
|
@ -283,11 +279,6 @@ M46l92gdOozT
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rootPool := x509.NewCertPool()
|
||||
|
||||
var (
|
||||
|
@ -305,15 +296,17 @@ M46l92gdOozT
|
|||
}
|
||||
rootPool.AddCert(cert)
|
||||
}
|
||||
return &X5C{
|
||||
p := &X5C{
|
||||
Name: name,
|
||||
Type: "X5C",
|
||||
Roots: root,
|
||||
Claims: &globalProvisionerClaims,
|
||||
audiences: testAudiences,
|
||||
claimer: claimer,
|
||||
rootPool: rootPool,
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences,
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateOIDC() (*OIDC, error) {
|
||||
|
@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OIDC{
|
||||
p := &OIDC{
|
||||
Name: name,
|
||||
Type: "OIDC",
|
||||
ClientID: clientID,
|
||||
|
@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) {
|
|||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
expiry: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
claimer: claimer,
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences,
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateGCP() (*GCP, error) {
|
||||
|
@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &GCP{
|
||||
p := &GCP{
|
||||
Type: "GCP",
|
||||
Name: name,
|
||||
ServiceAccounts: []string{serviceAccount},
|
||||
Claims: &globalProvisionerClaims,
|
||||
claimer: claimer,
|
||||
config: newGCPConfig(),
|
||||
keyStore: &keyStore{
|
||||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
expiry: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
audiences: testAudiences.WithFragment("gcp/" + name),
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences.WithFragment("gcp/" + name),
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateAWS() (*AWS, error) {
|
||||
|
@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block, _ := pem.Decode([]byte(awsTestCertificate))
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return nil, errors.New("error decoding AWS certificate")
|
||||
|
@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) {
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing AWS certificate")
|
||||
}
|
||||
return &AWS{
|
||||
p := &AWS{
|
||||
Type: "AWS",
|
||||
Name: name,
|
||||
Accounts: []string{accountID},
|
||||
Claims: &globalProvisionerClaims,
|
||||
IMDSVersions: []string{"v2", "v1"},
|
||||
claimer: claimer,
|
||||
config: &awsConfig{
|
||||
identityURL: awsIdentityURL,
|
||||
signatureURL: awsSignatureURL,
|
||||
|
@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) {
|
|||
certificates: []*x509.Certificate{cert},
|
||||
signatureAlgorithm: awsSignatureAlgorithm,
|
||||
},
|
||||
audiences: testAudiences.WithFragment("aws/" + name),
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences.WithFragment("aws/" + name),
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateAWSWithServer() (*AWS, *httptest.Server, error) {
|
||||
|
@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
block, _ := pem.Decode([]byte(awsTestCertificate))
|
||||
if block == nil || block.Type != "CERTIFICATE" {
|
||||
return nil, errors.New("error decoding AWS certificate")
|
||||
|
@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) {
|
|||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error parsing AWS certificate")
|
||||
}
|
||||
return &AWS{
|
||||
p := &AWS{
|
||||
Type: "AWS",
|
||||
Name: name,
|
||||
Accounts: []string{accountID},
|
||||
Claims: &globalProvisionerClaims,
|
||||
IMDSVersions: []string{"v1"},
|
||||
claimer: claimer,
|
||||
config: &awsConfig{
|
||||
identityURL: awsIdentityURL,
|
||||
signatureURL: awsSignatureURL,
|
||||
|
@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) {
|
|||
certificates: []*x509.Certificate{cert},
|
||||
signatureAlgorithm: awsSignatureAlgorithm,
|
||||
},
|
||||
audiences: testAudiences.WithFragment("aws/" + name),
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences.WithFragment("aws/" + name),
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
|
||||
|
@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
claimer, err := NewClaimer(nil, globalProvisionerClaims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Azure{
|
||||
p := &Azure{
|
||||
Type: "Azure",
|
||||
Name: name,
|
||||
TenantID: tenantID,
|
||||
Audience: azureDefaultAudience,
|
||||
Claims: &globalProvisionerClaims,
|
||||
claimer: claimer,
|
||||
config: newAzureConfig(tenantID),
|
||||
oidcConfig: openIDConfiguration{
|
||||
Issuer: "https://sts.windows.net/" + tenantID + "/",
|
||||
|
@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) {
|
|||
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
|
||||
expiry: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
p.ctl, err = NewController(p, p.Claims, Config{
|
||||
Audiences: testAudiences,
|
||||
})
|
||||
return p, err
|
||||
}
|
||||
|
||||
func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
||||
|
@ -671,7 +656,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
|
|||
w.Header().Add("Cache-Control", "max-age=5")
|
||||
writeJSON(w, getPublic(az.keyStore.keySet))
|
||||
case "/metadata/identity/oauth2/token":
|
||||
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0])
|
||||
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0])
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
} else {
|
||||
|
@ -1009,7 +994,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r
|
|||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
||||
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
|
||||
sig, err := jose.NewSigner(
|
||||
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
|
||||
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
|
||||
|
@ -1017,6 +1002,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var xmsMirID string
|
||||
if resourceType == "vm" {
|
||||
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName)
|
||||
} else if resourceType == "uai" {
|
||||
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName)
|
||||
}
|
||||
|
||||
claims := azurePayload{
|
||||
Claims: jose.Claims{
|
||||
|
@ -1034,7 +1025,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
|
|||
ObjectID: "the-oid",
|
||||
TenantID: tenantID,
|
||||
Version: "the-version",
|
||||
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine),
|
||||
XMSMirID: xmsMirID,
|
||||
}
|
||||
return jose.Signed(sig).Claims(claims).CompactSerialize()
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/jose"
|
||||
"go.step.sm/crypto/sshutil"
|
||||
|
@ -32,12 +33,11 @@ type X5C struct {
|
|||
Roots []byte `json:"roots"`
|
||||
Claims *Claims `json:"claims,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
claimer *Claimer
|
||||
audiences Audiences
|
||||
ctl *Controller
|
||||
rootPool *x509.CertPool
|
||||
x509Policy x509PolicyEngine
|
||||
sshHostPolicy *hostPolicyEngine
|
||||
sshUserPolicy *userPolicyEngine
|
||||
x509Policy policy.X509Policy
|
||||
sshHostPolicy policy.HostPolicy
|
||||
sshUserPolicy policy.UserPolicy
|
||||
}
|
||||
|
||||
// GetID returns the provisioner unique identifier. The name and credential id
|
||||
|
@ -89,7 +89,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) {
|
|||
}
|
||||
|
||||
// Init initializes and validates the fields of a X5C type.
|
||||
func (p *X5C) Init(config Config) error {
|
||||
func (p *X5C) Init(config Config) (err error) {
|
||||
switch {
|
||||
case p.Type == "":
|
||||
return errors.New("provisioner type cannot be empty")
|
||||
|
@ -104,6 +104,7 @@ func (p *X5C) Init(config Config) error {
|
|||
var (
|
||||
block *pem.Block
|
||||
rest = p.Roots
|
||||
count int
|
||||
)
|
||||
for rest != nil {
|
||||
block, rest = pem.Decode(rest)
|
||||
|
@ -114,37 +115,33 @@ func (p *X5C) Init(config Config) error {
|
|||
if err != nil {
|
||||
return errors.Wrap(err, "error parsing x509 certificate from PEM block")
|
||||
}
|
||||
count++
|
||||
p.rootPool.AddCert(cert)
|
||||
}
|
||||
|
||||
// Verify that at least one root was found.
|
||||
if len(p.rootPool.Subjects()) == 0 {
|
||||
if count == 0 {
|
||||
return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName())
|
||||
}
|
||||
|
||||
// Update claims with global ones
|
||||
var err error
|
||||
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the x509 allow/deny policy engine
|
||||
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for user certificates
|
||||
if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize the SSH allow/deny policy engine for host certificates
|
||||
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
return nil
|
||||
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
|
||||
p.ctl, err = NewController(p, p.Claims, config)
|
||||
return
|
||||
}
|
||||
|
||||
// authorizeToken performs common jwt authorization actions and returns the
|
||||
|
@ -207,13 +204,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
|
|||
// AuthorizeRevoke returns an error if the provisioner does not have rights to
|
||||
// revoke the certificate with serial number in the `sub` property.
|
||||
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
|
||||
_, err := p.authorizeToken(token, p.audiences.Revoke)
|
||||
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
|
||||
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
|
||||
}
|
||||
|
||||
// AuthorizeSign validates the given token.
|
||||
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
claims, err := p.authorizeToken(token, p.audiences.Sign)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
|
||||
}
|
||||
|
@ -245,32 +242,31 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
|
|||
templateOptions,
|
||||
// modifiers / withOptions
|
||||
newProvisionerExtensionOption(TypeX5C, p.Name, ""),
|
||||
profileLimitDuration{p.claimer.DefaultTLSCertDuration(),
|
||||
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter},
|
||||
profileLimitDuration{
|
||||
p.ctl.Claimer.DefaultTLSCertDuration(),
|
||||
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter,
|
||||
},
|
||||
// validators
|
||||
commonNameValidator(claims.Subject),
|
||||
defaultSANsValidator(claims.SANs),
|
||||
defaultPublicKeyValidator{},
|
||||
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()),
|
||||
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
|
||||
newX509NamePolicyValidator(p.x509Policy),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeRenew returns an error if the renewal is disabled.
|
||||
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
|
||||
if p.claimer.IsDisableRenewal() {
|
||||
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName())
|
||||
}
|
||||
return nil
|
||||
return p.ctl.AuthorizeRenew(ctx, cert)
|
||||
}
|
||||
|
||||
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
|
||||
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
|
||||
if !p.claimer.IsSSHCAEnabled() {
|
||||
if !p.ctl.Claimer.IsSSHCAEnabled() {
|
||||
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
|
||||
}
|
||||
|
||||
claims, err := p.authorizeToken(token, p.audiences.SSHSign)
|
||||
claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
|
||||
if err != nil {
|
||||
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
|
||||
}
|
||||
|
@ -333,11 +329,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
|
|||
|
||||
return append(signOptions,
|
||||
// Checks the validity bounds, and set the validity if has not been set.
|
||||
&sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter},
|
||||
&sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter},
|
||||
// Validate public key.
|
||||
&sshDefaultPublicKeyValidator{},
|
||||
// Validate the validity period.
|
||||
&sshCertValidityValidator{p.claimer},
|
||||
&sshCertValidityValidator{p.ctl.Claimer},
|
||||
// Require all the fields in the SSH certificate
|
||||
&sshCertDefaultValidator{},
|
||||
// Ensure that all principal names are allowed
|
||||
|
|
|
@ -2,6 +2,7 @@ package provisioner
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -69,7 +70,7 @@ func TestX5C_Init(t *testing.T) {
|
|||
},
|
||||
"fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest {
|
||||
return ProvisionerValidateTest{
|
||||
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences},
|
||||
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")},
|
||||
err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"),
|
||||
}
|
||||
},
|
||||
|
@ -117,6 +118,8 @@ M46l92gdOozT
|
|||
return ProvisionerValidateTest{
|
||||
p: p,
|
||||
extraValid: func(p *X5C) error {
|
||||
// nolint:staticcheck // We don't have a different way to
|
||||
// check the number of certificates in the pool.
|
||||
numCerts := len(p.rootPool.Subjects())
|
||||
if numCerts != 2 {
|
||||
return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
|
||||
|
@ -141,7 +144,7 @@ M46l92gdOozT
|
|||
}
|
||||
} else {
|
||||
if assert.Nil(t, tc.err) {
|
||||
assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID()))
|
||||
assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID()))
|
||||
if tc.extraValid != nil {
|
||||
assert.Nil(t, tc.extraValid(tc.p))
|
||||
}
|
||||
|
@ -468,13 +471,13 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
|||
switch v := o.(type) {
|
||||
case certificateOptionsFunc:
|
||||
case *provisionerExtensionOption:
|
||||
assert.Equals(t, v.Type, int(TypeX5C))
|
||||
assert.Equals(t, v.Type, TypeX5C)
|
||||
assert.Equals(t, v.Name, tc.p.GetName())
|
||||
assert.Equals(t, v.CredentialID, "")
|
||||
assert.Len(t, 0, v.KeyValuePairs)
|
||||
case profileLimitDuration:
|
||||
assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration())
|
||||
claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign)
|
||||
assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration())
|
||||
claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign)
|
||||
assert.FatalError(t, err)
|
||||
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter)
|
||||
case commonNameValidator:
|
||||
|
@ -483,8 +486,8 @@ func TestX5C_AuthorizeSign(t *testing.T) {
|
|||
case defaultSANsValidator:
|
||||
assert.Equals(t, []string(v), tc.sans)
|
||||
case *validityValidator:
|
||||
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration())
|
||||
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
|
||||
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
|
||||
case *x509NamePolicyValidator:
|
||||
assert.Equals(t, nil, v.policyEngine)
|
||||
default:
|
||||
|
@ -552,6 +555,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestX5C_AuthorizeRenew(t *testing.T) {
|
||||
now := time.Now().Truncate(time.Second)
|
||||
type test struct {
|
||||
p *X5C
|
||||
code int
|
||||
|
@ -564,12 +568,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
|
|||
// disable renewal
|
||||
disable := true
|
||||
p.Claims = &Claims{DisableRenewal: &disable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
code: http.StatusUnauthorized,
|
||||
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()),
|
||||
err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
|
||||
}
|
||||
},
|
||||
"ok": func(t *testing.T) test {
|
||||
|
@ -583,7 +587,10 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
|
|||
for name, tt := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
tc := tt(t)
|
||||
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil {
|
||||
if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(time.Hour),
|
||||
}); err != nil {
|
||||
if assert.NotNil(t, tc.err) {
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
|
@ -619,7 +626,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
|||
// disable sshCA
|
||||
enable := false
|
||||
p.Claims = &Claims{EnableSSHCA: &enable}
|
||||
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
|
||||
assert.FatalError(t, err)
|
||||
return test{
|
||||
p: p,
|
||||
|
@ -775,10 +782,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
|
|||
case sshCertDefaultsModifier:
|
||||
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
|
||||
case *sshLimitDuration:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
|
||||
case *sshCertValidityValidator:
|
||||
assert.Equals(t, v.Claimer, tc.p.claimer)
|
||||
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
|
||||
case *sshNamePolicyValidator:
|
||||
assert.Equals(t, nil, v.userPolicyEngine)
|
||||
assert.Equals(t, nil, v.hostPolicyEngine)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/admin"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/policy"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/cli-utils/step"
|
||||
|
@ -109,6 +110,8 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.
|
|||
HostKeys: sshKeys.HostKeys,
|
||||
},
|
||||
GetIdentityFunc: a.getIdentityFunc,
|
||||
AuthorizeRenewFunc: a.authorizeRenewFunc,
|
||||
AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
@ -395,6 +398,58 @@ func optionsToCertificates(p *linkedca.Provisioner) *provisioner.Options {
|
|||
ops.SSH.Template = string(p.SshTemplate.Template)
|
||||
ops.SSH.TemplateData = p.SshTemplate.Data
|
||||
}
|
||||
if p.Policy != nil {
|
||||
if p.Policy.X509 != nil {
|
||||
if p.Policy.X509.Allow != nil {
|
||||
ops.X509.AllowedNames = &policy.X509NameOptions{
|
||||
DNSDomains: p.Policy.X509.Allow.Dns,
|
||||
IPRanges: p.Policy.X509.Allow.Ips,
|
||||
EmailAddresses: p.Policy.X509.Allow.Emails,
|
||||
URIDomains: p.Policy.X509.Allow.Uris,
|
||||
}
|
||||
}
|
||||
if p.Policy.X509.Deny != nil {
|
||||
ops.X509.DeniedNames = &policy.X509NameOptions{
|
||||
DNSDomains: p.Policy.X509.Deny.Dns,
|
||||
IPRanges: p.Policy.X509.Deny.Ips,
|
||||
EmailAddresses: p.Policy.X509.Deny.Emails,
|
||||
URIDomains: p.Policy.X509.Deny.Uris,
|
||||
}
|
||||
}
|
||||
}
|
||||
if p.Policy.Ssh != nil {
|
||||
if p.Policy.Ssh.Host != nil {
|
||||
ops.SSH.Host = &policy.SSHHostCertificateOptions{}
|
||||
if p.Policy.Ssh.Host.Allow != nil {
|
||||
ops.SSH.Host.AllowedNames = &policy.SSHNameOptions{
|
||||
DNSDomains: p.Policy.Ssh.Host.Allow.Dns,
|
||||
IPRanges: p.Policy.Ssh.Host.Allow.Ips,
|
||||
}
|
||||
}
|
||||
if p.Policy.Ssh.Host.Deny != nil {
|
||||
ops.SSH.Host.DeniedNames = &policy.SSHNameOptions{
|
||||
DNSDomains: p.Policy.Ssh.Host.Deny.Dns,
|
||||
IPRanges: p.Policy.Ssh.Host.Deny.Ips,
|
||||
}
|
||||
}
|
||||
}
|
||||
if p.Policy.Ssh.User != nil {
|
||||
ops.SSH.User = &policy.SSHUserCertificateOptions{}
|
||||
if p.Policy.Ssh.User.Allow != nil {
|
||||
ops.SSH.User.AllowedNames = &policy.SSHNameOptions{
|
||||
EmailAddresses: p.Policy.Ssh.User.Allow.Emails,
|
||||
Principals: p.Policy.Ssh.User.Allow.Principals,
|
||||
}
|
||||
}
|
||||
if p.Policy.Ssh.User.Deny != nil {
|
||||
ops.SSH.User.DeniedNames = &policy.SSHNameOptions{
|
||||
EmailAddresses: p.Policy.Ssh.User.Deny.Emails,
|
||||
Principals: p.Policy.Ssh.User.Deny.Principals,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ops
|
||||
}
|
||||
|
||||
|
@ -436,6 +491,7 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
|
|||
|
||||
pc := &provisioner.Claims{
|
||||
DisableRenewal: &c.DisableRenewal,
|
||||
AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry,
|
||||
}
|
||||
|
||||
var err error
|
||||
|
@ -473,12 +529,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
|
|||
}
|
||||
|
||||
disableRenewal := config.DefaultDisableRenewal
|
||||
allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry
|
||||
|
||||
if c.DisableRenewal != nil {
|
||||
disableRenewal = *c.DisableRenewal
|
||||
}
|
||||
if c.AllowRenewAfterExpiry != nil {
|
||||
allowRenewAfterExpiry = *c.AllowRenewAfterExpiry
|
||||
}
|
||||
|
||||
lc := &linkedca.Claims{
|
||||
DisableRenewal: disableRenewal,
|
||||
AllowRenewAfterExpiry: allowRenewAfterExpiry,
|
||||
}
|
||||
|
||||
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/certificates/authority/config"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/db"
|
||||
|
@ -241,6 +242,45 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
|
|||
return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType)
|
||||
}
|
||||
|
||||
switch certTpl.CertType {
|
||||
case ssh.UserCert:
|
||||
// when no user policy engine is configured, but a host policy engine is
|
||||
// configured, the user certificate is denied.
|
||||
if a.sshUserPolicy == nil && a.sshHostPolicy != nil {
|
||||
return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign ssh user certificates"), "authority.SignSSH: error creating ssh user certificate")
|
||||
}
|
||||
if a.sshUserPolicy != nil {
|
||||
allowed, err := a.sshUserPolicy.ArePrincipalsAllowed(certTpl)
|
||||
if err != nil {
|
||||
return nil, errs.InternalServerErr(err,
|
||||
errs.WithMessage("authority.SignSSH: error creating ssh user certificate"),
|
||||
)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign"), "authority.SignSSH: error creating ssh user certificate")
|
||||
}
|
||||
}
|
||||
case ssh.HostCert:
|
||||
// when no host policy engine is configured, but a user policy engine is
|
||||
// configured, the host certificate is denied.
|
||||
if a.sshHostPolicy == nil && a.sshUserPolicy != nil {
|
||||
return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign ssh host certificates"), "authority.SignSSH: error creating ssh user certificate")
|
||||
}
|
||||
if a.sshHostPolicy != nil {
|
||||
allowed, err := a.sshHostPolicy.ArePrincipalsAllowed(certTpl)
|
||||
if err != nil {
|
||||
return nil, errs.InternalServerErr(err,
|
||||
errs.WithMessage("authority.SignSSH: error creating ssh host certificate"),
|
||||
)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign"), "authority.SignSSH: error creating ssh host certificate")
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType)
|
||||
}
|
||||
|
||||
// Sign certificate.
|
||||
cert, err := sshutil.CreateCertificate(certTpl, signer)
|
||||
if err != nil {
|
||||
|
|
|
@ -191,6 +191,22 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
|
|||
}
|
||||
}
|
||||
|
||||
// Check if authority is allowed to sign the certificate
|
||||
var allowedToSign bool
|
||||
if allowedToSign, err = a.isAllowedToSign(leaf); err != nil {
|
||||
return nil, errs.InternalServerErr(err,
|
||||
errs.WithKeyVal("csr", csr),
|
||||
errs.WithKeyVal("signOptions", signOpts),
|
||||
errs.WithMessage("error creating certificate"),
|
||||
)
|
||||
}
|
||||
if !allowedToSign {
|
||||
return nil, errs.ApplyOptions(
|
||||
errs.ForbiddenErr(errors.New("authority not allowed to sign"), "error creating certificate"),
|
||||
opts...,
|
||||
)
|
||||
}
|
||||
|
||||
// Sign certificate
|
||||
lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate))
|
||||
resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{
|
||||
|
@ -214,6 +230,21 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
|
|||
return fullchain, nil
|
||||
}
|
||||
|
||||
// isAllowedToSign checks if the Authority is allowed to sign the X.509 certificate.
|
||||
// It first checks if the certificate contains an admin subject that exists in the
|
||||
// collection of admins. The CA is always allowed to sign those. If the cert contains
|
||||
// different names and a policy is configured, the policy will be executed against
|
||||
// the cert to see if the CA is allowed to sign it.
|
||||
func (a *Authority) isAllowedToSign(cert *x509.Certificate) (bool, error) {
|
||||
|
||||
// if no policy is configured, the cert is implicitly allowed
|
||||
if a.x509Policy == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return a.x509Policy.AreCertificateNamesAllowed(cert)
|
||||
}
|
||||
|
||||
// Renew creates a new Certificate identical to the old certificate, except
|
||||
// with a validity window that begins 'now'.
|
||||
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {
|
||||
|
|
|
@ -757,7 +757,7 @@ func TestAuthority_Renew(t *testing.T) {
|
|||
|
||||
now := time.Now().UTC()
|
||||
nb1 := now.Add(-time.Minute * 7)
|
||||
na1 := now
|
||||
na1 := now.Add(time.Hour)
|
||||
so := &provisioner.SignOptions{
|
||||
NotBefore: provisioner.NewTimeDuration(nb1),
|
||||
NotAfter: provisioner.NewTimeDuration(na1),
|
||||
|
@ -798,7 +798,20 @@ func TestAuthority_Renew(t *testing.T) {
|
|||
"fail/unauthorized": func() (*renewTest, error) {
|
||||
return &renewTest{
|
||||
cert: certNoRenew,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"),
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
"fail/WithAuthorizeRenewFunc": func() (*renewTest, error) {
|
||||
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
|
||||
return errs.Unauthorized("not authorized")
|
||||
}))
|
||||
aa.x509CAService = a.x509CAService
|
||||
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
|
||||
return &renewTest{
|
||||
auth: aa,
|
||||
cert: cert,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
|
@ -820,6 +833,17 @@ func TestAuthority_Renew(t *testing.T) {
|
|||
cert: cert,
|
||||
}, nil
|
||||
},
|
||||
"ok/WithAuthorizeRenewFunc": func() (*renewTest, error) {
|
||||
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
|
||||
return nil
|
||||
}))
|
||||
aa.x509CAService = a.x509CAService
|
||||
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
|
||||
return &renewTest{
|
||||
auth: aa,
|
||||
cert: cert,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
for name, genTestCase := range tests {
|
||||
|
@ -856,7 +880,7 @@ func TestAuthority_Renew(t *testing.T) {
|
|||
|
||||
expiry := now.Add(time.Minute * 7)
|
||||
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
|
||||
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
|
||||
|
||||
tmplt := a.config.AuthorityConfig.Template
|
||||
assert.Equals(t, leaf.Subject.String(),
|
||||
|
@ -956,7 +980,7 @@ func TestAuthority_Rekey(t *testing.T) {
|
|||
|
||||
now := time.Now().UTC()
|
||||
nb1 := now.Add(-time.Minute * 7)
|
||||
na1 := now
|
||||
na1 := now.Add(time.Hour)
|
||||
so := &provisioner.SignOptions{
|
||||
NotBefore: provisioner.NewTimeDuration(nb1),
|
||||
NotAfter: provisioner.NewTimeDuration(na1),
|
||||
|
@ -998,7 +1022,7 @@ func TestAuthority_Rekey(t *testing.T) {
|
|||
"fail/unauthorized": func() (*renewTest, error) {
|
||||
return &renewTest{
|
||||
cert: certNoRenew,
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"),
|
||||
err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
|
||||
code: http.StatusUnauthorized,
|
||||
}, nil
|
||||
},
|
||||
|
@ -1063,7 +1087,7 @@ func TestAuthority_Rekey(t *testing.T) {
|
|||
|
||||
expiry := now.Add(time.Minute * 7)
|
||||
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
|
||||
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute)))
|
||||
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
|
||||
|
||||
tmplt := a.config.AuthorityConfig.Template
|
||||
assert.Equals(t, leaf.Subject.String(),
|
||||
|
|
|
@ -408,6 +408,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
|
|||
server.ServeTLS(listener, "", "")
|
||||
}()
|
||||
defer server.Close()
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
// Create bootstrap client
|
||||
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
|
||||
|
@ -419,7 +420,6 @@ func TestBootstrapClientServerRotation(t *testing.T) {
|
|||
|
||||
// doTest does a request that requires mTLS
|
||||
doTest := func(client *http.Client) error {
|
||||
time.Sleep(1 * time.Second)
|
||||
// test with ca
|
||||
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
|
||||
if err != nil {
|
||||
|
|
6
ca/ca.go
6
ca/ca.go
|
@ -208,7 +208,8 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
|
|||
adminDB := auth.GetAdminDatabase()
|
||||
if adminDB != nil {
|
||||
acmeAdminResponder := adminAPI.NewACMEAdminResponder()
|
||||
adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder)
|
||||
policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB)
|
||||
adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder, policyAdminResponder)
|
||||
mux.Route("/admin", func(r chi.Router) {
|
||||
adminHandler.Route(r)
|
||||
})
|
||||
|
@ -450,9 +451,6 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
|
|||
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
tlsConfig.ClientCAs = certPool
|
||||
|
||||
// Use server's most preferred ciphersuite
|
||||
tlsConfig.PreferServerCipherSuites = true
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
||||
|
||||
|
|
35
ca/client.go
35
ca/client.go
|
@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// GetCaURL returns the configured CA url.
|
||||
func (c *Client) GetCaURL() string {
|
||||
return c.endpoint.String()
|
||||
}
|
||||
|
||||
// GetRootCAs returns the RootCAs certificate pool from the configured
|
||||
// transport.
|
||||
func (c *Client) GetRootCAs() *x509.CertPool {
|
||||
|
@ -723,6 +728,36 @@ retry:
|
|||
return &sign, nil
|
||||
}
|
||||
|
||||
// RenewWithToken performs the renew request to the CA with the given
|
||||
// authorization token and returns the api.SignResponse struct. This method is
|
||||
// generally used to renew an expired certificate.
|
||||
func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) {
|
||||
var retried bool
|
||||
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
|
||||
req, err := http.NewRequest("POST", u.String(), http.NoBody)
|
||||
if err != nil {
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request")
|
||||
}
|
||||
req.Header.Add("Authorization", "Bearer "+token)
|
||||
retry:
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u)
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
if !retried && c.retryOnError(resp) {
|
||||
retried = true
|
||||
goto retry
|
||||
}
|
||||
return nil, readError(resp.Body)
|
||||
}
|
||||
var sign api.SignResponse
|
||||
if err := readJSON(resp.Body, &sign); err != nil {
|
||||
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u)
|
||||
}
|
||||
return &sign, nil
|
||||
}
|
||||
|
||||
// Rekey performs the rekey request to the CA and returns the api.SignResponse
|
||||
// struct.
|
||||
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {
|
||||
|
|
|
@ -17,13 +17,16 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"go.step.sm/crypto/x509util"
|
||||
|
||||
"github.com/smallstep/assert"
|
||||
"github.com/smallstep/certificates/api"
|
||||
"github.com/smallstep/certificates/api/read"
|
||||
"github.com/smallstep/certificates/authority"
|
||||
"github.com/smallstep/certificates/authority/provisioner"
|
||||
"github.com/smallstep/certificates/errs"
|
||||
"go.step.sm/crypto/x509util"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -354,7 +357,7 @@ func TestClient_Sign(t *testing.T) {
|
|||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body := new(api.SignRequest)
|
||||
if err := api.ReadJSON(req.Body, body); err != nil {
|
||||
if err := read.JSON(req.Body, body); err != nil {
|
||||
e, ok := tt.response.(error)
|
||||
assert.Fatal(t, ok, "response expected to be error type")
|
||||
api.WriteError(w, e)
|
||||
|
@ -426,7 +429,7 @@ func TestClient_Revoke(t *testing.T) {
|
|||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
body := new(api.RevokeRequest)
|
||||
if err := api.ReadJSON(req.Body, body); err != nil {
|
||||
if err := read.JSON(req.Body, body); err != nil {
|
||||
e, ok := tt.response.(error)
|
||||
assert.Fatal(t, ok, "response expected to be error type")
|
||||
api.WriteError(w, e)
|
||||
|
@ -529,6 +532,74 @@ func TestClient_Renew(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestClient_RenewWithToken(t *testing.T) {
|
||||
ok := &api.SignResponse{
|
||||
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
||||
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
|
||||
CertChainPEM: []api.Certificate{
|
||||
{Certificate: parseCertificate(certPEM)},
|
||||
{Certificate: parseCertificate(rootPEM)},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response interface{}
|
||||
responseCode int
|
||||
wantErr bool
|
||||
err error
|
||||
}{
|
||||
{"ok", ok, 200, false, nil},
|
||||
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
|
||||
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
|
||||
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
|
||||
}
|
||||
|
||||
srv := httptest.NewServer(nil)
|
||||
defer srv.Close()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Errorf("NewClient() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Header.Get("Authorization") != "Bearer token" {
|
||||
api.JSONStatus(w, errs.InternalServer("force"), 500)
|
||||
} else {
|
||||
api.JSONStatus(w, tt.response, tt.responseCode)
|
||||
}
|
||||
})
|
||||
|
||||
got, err := c.RenewWithToken("token")
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Printf("%+v", err)
|
||||
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case err != nil:
|
||||
if got != nil {
|
||||
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
|
||||
}
|
||||
|
||||
sc, ok := err.(errs.StatusCoder)
|
||||
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
|
||||
assert.Equals(t, sc.StatusCode(), tt.responseCode)
|
||||
assert.HasPrefix(t, err.Error(), tt.err.Error())
|
||||
default:
|
||||
if !reflect.DeepEqual(got, tt.response) {
|
||||
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Rekey(t *testing.T) {
|
||||
ok := &api.SignResponse{
|
||||
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
|
||||
|
@ -1060,3 +1131,28 @@ func TestClient_SSHBastion(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetCaURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
caURL string
|
||||
want string
|
||||
}{
|
||||
{"ok", "https://ca.com", "https://ca.com"},
|
||||
{"ok no schema", "ca.com", "https://ca.com"},
|
||||
{"ok with port", "https://ca.com:9000", "https://ca.com:9000"},
|
||||
{"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
|
||||
if err != nil {
|
||||
t.Errorf("NewClient() error = %v", err)
|
||||
return
|
||||
}
|
||||
if got := c.GetCaURL(); got != tt.want {
|
||||
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
@ -196,7 +197,7 @@ func TestLoadClient(t *testing.T) {
|
|||
switch {
|
||||
case gotTransport.TLSClientConfig.GetClientCertificate == nil:
|
||||
t.Error("LoadClient() transport does not define GetClientCertificate")
|
||||
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()):
|
||||
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs):
|
||||
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
|
||||
default:
|
||||
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
|
||||
|
@ -238,3 +239,23 @@ func Test_defaultsConfig_Validate(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// nolint:staticcheck,gocritic
|
||||
func equalPools(a, b *x509.CertPool) bool {
|
||||
if reflect.DeepEqual(a, b) {
|
||||
return true
|
||||
}
|
||||
subjects := a.Subjects()
|
||||
sA := make([]string, len(subjects))
|
||||
for i := range subjects {
|
||||
sA[i] = string(subjects[i])
|
||||
}
|
||||
subjects = b.Subjects()
|
||||
sB := make([]string, len(subjects))
|
||||
for i := range subjects {
|
||||
sB[i] = string(subjects[i])
|
||||
}
|
||||
sort.Strings(sA)
|
||||
sort.Strings(sB)
|
||||
return reflect.DeepEqual(sA, sB)
|
||||
}
|
||||
|
|
|
@ -346,6 +346,8 @@ func TestIdentity_GetCertPool(t *testing.T) {
|
|||
return
|
||||
}
|
||||
if got != nil {
|
||||
// nolint:staticcheck // we don't have a different way to check
|
||||
// the certificates in the pool.
|
||||
subjects := got.Subjects()
|
||||
if !reflect.DeepEqual(subjects, tt.wantSubjects) {
|
||||
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
|
||||
|
|
2
ca/testdata/ca.json
vendored
2
ca/testdata/ca.json
vendored
|
@ -6,7 +6,7 @@
|
|||
"password": "password",
|
||||
"address": "127.0.0.1:0",
|
||||
"dnsNames": ["127.0.0.1"],
|
||||
"logger": {"format": "text"},
|
||||
"_logger": {"format": "text"},
|
||||
"tls": {
|
||||
"minVersion": 1.2,
|
||||
"maxVersion": 1.3,
|
||||
|
|
2
ca/testdata/federated-ca.json
vendored
2
ca/testdata/federated-ca.json
vendored
|
@ -6,7 +6,7 @@
|
|||
"password": "asdf",
|
||||
"address": "127.0.0.1:0",
|
||||
"dnsNames": ["127.0.0.1"],
|
||||
"logger": {"format": "text"},
|
||||
"_logger": {"format": "text"},
|
||||
"tls": {
|
||||
"minVersion": 1.2,
|
||||
"maxVersion": 1.2,
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue