Merge pull request #847 from smallstep/herman/allow-deny-next

Refactor allow/deny (WIP)
This commit is contained in:
Herman Slatman 2022-03-24 13:13:19 +01:00 committed by GitHub
commit c45d177d52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
112 changed files with 4101 additions and 1274 deletions

View file

@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.15', '1.16', '1.17' ] go: [ '1.17', '1.18' ]
outputs: outputs:
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
steps: steps:
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.44.0' version: 'v1.45.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -106,7 +106,7 @@ jobs:
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.17 go-version: 1.18
- -
name: APT Install name: APT Install
id: aptInstall id: aptInstall
@ -159,7 +159,7 @@ jobs:
name: Setup Go name: Setup Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: '1.17' go-version: '1.18'
- -
name: Install cosign name: Install cosign
uses: sigstore/cosign-installer@v1.1.0 uses: sigstore/cosign-installer@v1.1.0

View file

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.16', '1.17' ] go: [ '1.17', '1.18' ]
steps: steps:
- -
name: Checkout name: Checkout
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.44.0' version: 'v1.45.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -58,7 +58,7 @@ jobs:
run: V=1 make ci run: V=1 make ci
- -
name: Codecov name: Codecov
if: matrix.go == '1.17' if: matrix.go == '1.18'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v1.2.1
with: with:
file: ./coverage.out # optional file: ./coverage.out # optional

1
.gitignore vendored
View file

@ -19,3 +19,4 @@ coverage.txt
output output
vendor vendor
.idea .idea
.envrc

View file

@ -19,6 +19,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64
@ -39,6 +40,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64
@ -59,6 +61,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64

View file

@ -6,7 +6,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased - 0.18.3] - DATE ## [Unreleased - 0.18.3] - DATE
### Added ### Added
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
### Changed ### Changed
- Made SCEP CA URL paths dynamic
- Support two latest versions of Go (1.17, 1.18)
### Deprecated ### Deprecated
### Removed ### Removed
### Fixed ### Fixed
@ -15,6 +18,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [0.18.2] - 2022-03-01 ## [0.18.2] - 2022-03-01
### Added ### Added
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner. - Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
out database drivers using Go tags. For example, using the Go flag
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
driver for PostgreSQL.
### Changed ### Changed
- IPv6 addresses are normalized as IP addresses instead of hostnames. - IPv6 addresses are normalized as IP addresses instead of hostnames.
- More descriptive JWK decryption error message. - More descriptive JWK decryption error message.

View file

@ -13,6 +13,7 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
) )
@ -104,10 +105,13 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
// management of allowed/denied names based on just the name, without having bound to EAB. Still, // management of allowed/denied names based on just the name, without having bound to EAB. Still,
// EAB is not illogical, because that's the way Accounts are connected to an external system and // EAB is not illogical, because that's the way Accounts are connected to an external system and
// thus make sense to also set the allowed/denied names based on that info. // thus make sense to also set the allowed/denied names based on that info.
// TODO: also perform check on the authority level here already, so that challenges are not performed
// and after that the CA fails to sign it. (i.e. h.ca function?)
for _, identifier := range nor.Identifiers { for _, identifier := range nor.Identifiers {
// TODO: gather all errors, so that we can build subproblems; include the nor.Validate() error here too, like in example? // TODO: gather all errors, so that we can build subproblems; include the nor.Validate() error here too, like in example?
err = prov.AuthorizeOrderIdentifier(ctx, identifier.Value) orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier)
if err != nil { if err != nil {
api.WriteError(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized")) api.WriteError(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return return

View file

@ -30,7 +30,7 @@ var clock Clock
// Provisioner is an interface that implements a subset of the provisioner.Interface -- // Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority. // only those methods required by the ACME api/authority.
type Provisioner interface { type Provisioner interface {
AuthorizeOrderIdentifier(ctx context.Context, identifier string) error AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
AuthorizeRevoke(ctx context.Context, token string) error AuthorizeRevoke(ctx context.Context, token string) error
GetID() string GetID() string
@ -45,7 +45,7 @@ type MockProvisioner struct {
Merr error Merr error
MgetID func() string MgetID func() string
MgetName func() string MgetName func() string
MauthorizeOrderIdentifier func(ctx context.Context, identifier string) error MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MauthorizeRevoke func(ctx context.Context, token string) error MauthorizeRevoke func(ctx context.Context, token string) error
MdefaultTLSCertDuration func() time.Duration MdefaultTLSCertDuration func() time.Duration
@ -61,7 +61,7 @@ func (m *MockProvisioner) GetName() string {
} }
// AuthorizeOrderIdentifiers mock // AuthorizeOrderIdentifiers mock
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error { func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
if m.MauthorizeOrderIdentifier != nil { if m.MauthorizeOrderIdentifier != nil {
return m.MauthorizeOrderIdentifier(ctx, identifier) return m.MauthorizeOrderIdentifier(ctx, identifier)
} }

View file

@ -20,6 +20,7 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -33,6 +34,7 @@ type Authority interface {
// context specifies the Authorize[Sign|Revoke|etc.] method. // context specifies the Authorize[Sign|Revoke|etc.] method.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error)
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -43,7 +45,7 @@ type Authority interface {
GetProvisioners(cursor string, limit int) (provisioner.List, string, error) GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
Revoke(context.Context, *authority.RevokeOptions) error Revoke(context.Context, *authority.RevokeOptions) error
GetEncryptedKey(kid string) (string, error) GetEncryptedKey(kid string) (string, error)
GetRoots() (federation []*x509.Certificate, err error) GetRoots() ([]*x509.Certificate, error)
GetFederation() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error)
Version() authority.Version Version() authority.Version
} }

View file

@ -13,6 +13,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
@ -27,14 +28,17 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -171,6 +175,7 @@ type mockAuthority struct {
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
authorizeSign func(ott string) ([]provisioner.SignOption, error) authorizeSign func(ott string) ([]provisioner.SignOption, error)
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -208,6 +213,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err
return m.ret1.([]provisioner.SignOption), m.err return m.ret1.([]provisioner.SignOption), m.err
} }
func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
if m.authorizeRenewToken != nil {
return m.authorizeRenewToken(ctx, ott)
}
return m.ret1.(*x509.Certificate), m.err
}
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
if m.getTLSOptions != nil { if m.getTLSOptions != nil {
return m.getTLSOptions() return m.getTLSOptions()
@ -920,48 +932,141 @@ func Test_caHandler_Renew(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
// Prepare root and leaf for renew after expiry test.
now := time.Now()
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
root := &x509.Certificate{
Subject: pkix.Name{CommonName: "Test Root CA"},
PublicKey: rootPub,
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
NotBefore: now.Add(-2 * time.Hour),
NotAfter: now.Add(time.Hour),
}
root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv)
if err != nil {
t.Fatal(err)
}
expiredLeaf := &x509.Certificate{
Subject: pkix.Name{CommonName: "Leaf certificate"},
PublicKey: leafPub,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
EmailAddresses: []string{"test@example.org"},
}
expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv)
if err != nil {
t.Fatal(err)
}
// Generate renew after expiry token
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)})
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so)
if err != nil {
t.Fatal(err)
}
generateX5cToken := func(claims jose.Claims) string {
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return s
}
tests := []struct { tests := []struct {
name string name string
tls *tls.ConnectionState tls *tls.ConnectionState
header http.Header
cert *x509.Certificate cert *x509.Certificate
root *x509.Certificate root *x509.Certificate
err error err error
statusCode int statusCode int
}{ }{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest}, {"ok renew after expiry", &tls.ConnectionState{}, http.Header{
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
})},
}, expiredLeaf, root, nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
{"fail expired token", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
})},
}, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized},
{"fail invalid root", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
})},
}, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized},
} }
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
if err != nil {
return nil, errs.Unauthorized(err.Error())
}
var claims jose.Claims
if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil {
return nil, errs.Unauthorized(err.Error())
}
if err := claims.ValidateWithLeeway(jose.Expected{
Time: now,
}, time.Minute); err != nil {
return nil, errs.Unauthorized(err.Error())
}
return chain[0][0], nil
},
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) }).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/renew", nil) req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls req.TLS = tt.tls
req.Header = tt.header
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Renew(logging.NewResponseLogger(w), req) h.Renew(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode { res := w.Result()
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) defer res.Body.Close()
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Renew unexpected error = %v", err) t.Errorf("caHandler.Renew unexpected error = %v", err)
} }
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
t.Errorf("%s", body)
}
if tt.statusCode < http.StatusBadRequest { if tt.statusCode < http.StatusBadRequest {
expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` +
`"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` +
`"certChain":["` +
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` +
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`)
if !bytes.Equal(bytes.TrimSpace(body), expected) { if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected)
} }
} }
}) })

View file

@ -7,7 +7,9 @@ import (
"os" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
@ -59,6 +61,6 @@ func WriteError(w http.ResponseWriter, err error) {
} }
if err := json.NewEncoder(w).Encode(err); err != nil { if err := json.NewEncoder(w).Encode(err); err != nil {
LogError(w, err) log.Error(w, err)
} }
} }

47
api/log/log.go Normal file
View file

@ -0,0 +1,47 @@
// Package log implements API-related logging helpers.
package log
import (
"log"
"net/http"
"github.com/smallstep/certificates/logging"
)
// Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func Error(rw http.ResponseWriter, err error) {
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
} else {
log.Println(err)
}
}
// EnabledResponse log the response object if it implements the EnableLogger
// interface.
func EnabledResponse(rw http.ResponseWriter, v interface{}) {
type enableLogger interface {
ToLog() (interface{}, error)
}
if el, ok := v.(enableLogger); ok {
out, err := el.ToLog()
if err != nil {
Error(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}

44
api/log/log_test.go Normal file
View file

@ -0,0 +1,44 @@
package log
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/smallstep/certificates/logging"
)
func TestError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Error(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}

31
api/read/read.go Normal file
View file

@ -0,0 +1,31 @@
// Package read implements request object readers.
package read
import (
"encoding/json"
"io"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/errs"
)
// JSON reads JSON from the request body and stores it in the value
// pointed by v.
func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
}

46
api/read/read_test.go Normal file
View file

@ -0,0 +1,46 @@
package read
import (
"io"
"reflect"
"strings"
"testing"
"github.com/smallstep/certificates/errs"
)
func TestJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := JSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

View file

@ -3,6 +3,7 @@ package api
import ( import (
"net/http" "net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -32,7 +33,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
} }
var body RekeyRequest var body RekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -1,20 +1,28 @@
package api package api
import ( import (
"crypto/x509"
"net/http" "net/http"
"strings"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
const (
authorizationHeader = "Authorization"
bearerScheme = "Bearer"
)
// Renew uses the information of certificate in the TLS connection to create a // Renew uses the information of certificate in the TLS connection to create a
// new one. // new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { cert, err := h.getPeerCertificate(r)
WriteError(w, errs.BadRequest("missing client certificate")) if err != nil {
WriteError(w, err)
return return
} }
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) certChain, err := h.Authority.Renew(cert)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return return
@ -33,3 +41,15 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil
}
if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
}
}
return nil, errs.BadRequest("missing client certificate")
}

View file

@ -4,11 +4,13 @@ import (
"context" "context"
"net/http" "net/http"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
) )
// RevokeResponse is the response object that returns the health of the server. // RevokeResponse is the response object that returns the health of the server.
@ -48,7 +50,7 @@ func (r *RevokeRequest) Validate() (err error) {
// TODO: Add CRL and OCSP support. // TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -13,6 +13,7 @@ import (
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -49,7 +50,7 @@ type SignResponse struct {
// information in the certificate request. // information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -9,12 +9,14 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
) )
// SSHAuthority is the interface implemented by a SSH CA authority. // SSHAuthority is the interface implemented by a SSH CA authority.
@ -249,7 +251,7 @@ type SSHBastionResponse struct {
// the request. // the request.
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
@ -393,7 +395,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
// and servers. // and servers.
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
@ -425,7 +427,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
@ -464,7 +466,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
// SSHBastion provides returns the bastion configured if any. // SSHBastion provides returns the bastion configured if any.
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -4,9 +4,11 @@ import (
"net/http" "net/http"
"time" "time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
) )
// SSHRekeyRequest is the request body of an SSH certificate request. // SSHRekeyRequest is the request body of an SSH certificate request.
@ -38,7 +40,7 @@ type SSHRekeyResponse struct {
// the request. // the request.
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -6,6 +6,8 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -36,7 +38,7 @@ type SSHRenewResponse struct {
// the request. // the request.
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -3,11 +3,13 @@ package api
import ( import (
"net/http" "net/http"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
) )
// SSHRevokeResponse is the response object that returns the health of the server. // SSHRevokeResponse is the response object that returns the health of the server.
@ -47,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
// NOTE: currently only Passive revocation is supported. // NOTE: currently only Passive revocation is supported.
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -18,12 +18,13 @@ import (
"testing" "testing"
"time" "time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
) )
var ( var (

View file

@ -2,14 +2,15 @@ package api
import ( import (
"encoding/json" "encoding/json"
"errors"
"io" "io"
"log"
"net/http" "net/http"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/log"
) )
// EnableLogger is an interface that enables response logging for an object. // EnableLogger is an interface that enables response logging for an object.
@ -17,38 +18,6 @@ type EnableLogger interface {
ToLog() (interface{}, error) ToLog() (interface{}, error)
} }
// LogError adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func LogError(rw http.ResponseWriter, err error) {
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
} else {
log.Println(err)
}
}
// LogEnabledResponse log the response object if it implements the EnableLogger
// interface.
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
if el, ok := v.(EnableLogger); ok {
out, err := el.ToLog()
if err != nil {
LogError(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}
// JSON writes the passed value into the http.ResponseWriter. // JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) { func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK) JSONStatus(w, v, http.StatusOK)
@ -60,10 +29,19 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil { if err := json.NewEncoder(w).Encode(v); err != nil {
LogError(w, err) log.Error(w, err)
return return
} }
LogEnabledResponse(w, v)
log.EnabledResponse(w, v)
}
// JSONNotFound writes a HTTP Not Found response with empty body.
func JSONNotFound(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound)
log.EnabledResponse(w, nil)
} }
// ProtoJSON writes the passed value into the http.ResponseWriter. // ProtoJSON writes the passed value into the http.ResponseWriter.
@ -79,31 +57,62 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
b, err := protojson.Marshal(m) b, err := protojson.Marshal(m)
if err != nil { if err != nil {
LogError(w, err) log.Error(w, err)
return return
} }
if _, err := w.Write(b); err != nil { if _, err := w.Write(b); err != nil {
LogError(w, err) log.Error(w, err)
return return
} }
//LogEnabledResponse(w, v)
// log.EnabledResponse(w, v)
} }
// ReadJSON reads JSON from the request body and stores it in the value // ReadProtoJSONWithCheck reads JSON from the request body and stores it in the value
// pointed by v. // pointed by v. TODO(hs): move this to and integrate with render package.
func ReadJSON(r io.Reader, v interface{}) error { func ReadProtoJSONWithCheck(w http.ResponseWriter, r io.Reader, m proto.Message) bool {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
// ReadProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return errs.BadRequestErr(err, "error reading request body") var wrapper = struct {
Status int `json:"code"`
Message string `json:"message"`
}{
Status: http.StatusBadRequest,
Message: err.Error(),
}
errData, err := json.Marshal(wrapper)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write(errData)
return false
} }
return protojson.Unmarshal(data, m) if err := protojson.Unmarshal(data, m); err != nil {
if errors.Is(err, proto.Error) {
var wrapper = struct {
Message string `json:"message"`
}{
Message: err.Error(),
}
errData, err := json.Marshal(wrapper)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write(errData)
return false
}
// fallback to the default error writer
WriteError(w, err)
return false
}
return true
} }

View file

@ -1,50 +1,13 @@
package api package api
import ( import (
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings"
"testing" "testing"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
func TestLogError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
LogError(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}
func TestJSON(t *testing.T) { func TestJSON(t *testing.T) {
type args struct { type args struct {
rw http.ResponseWriter rw http.ResponseWriter
@ -88,38 +51,3 @@ func TestJSON(t *testing.T) {
}) })
} }
} }
func TestReadJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ReadJSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

View file

@ -14,7 +14,7 @@ import (
const ( const (
// provisionerContextKey provisioner key // provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner") provisionerContextKey = admin.ContextKey("provisioner")
) )
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests // CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests

View file

@ -5,10 +5,13 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
) )
type adminAuthority interface { type adminAuthority interface {
@ -25,6 +28,10 @@ type adminAuthority interface {
LoadProvisionerByID(id string) (provisioner.Interface, error) LoadProvisionerByID(id string) (provisioner.Interface, error)
UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error
RemoveProvisioner(ctx context.Context, id string) error RemoveProvisioner(ctx context.Context, id string) error
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
RemoveAuthorityPolicy(ctx context.Context) error
} }
// CreateAdminRequest represents the body for a CreateAdmin request. // CreateAdminRequest represents the body for a CreateAdmin request.
@ -112,7 +119,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
// CreateAdmin creates a new admin. // CreateAdmin creates a new admin.
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest var body CreateAdminRequest
if err := api.ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }
@ -156,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
// UpdateAdmin updates an existing admin. // UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest var body UpdateAdminRequest
if err := api.ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }

View file

@ -14,11 +14,13 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
) )
type mockAdminAuthority struct { type mockAdminAuthority struct {
@ -37,6 +39,11 @@ type mockAdminAuthority struct {
MockLoadProvisionerByID func(id string) (provisioner.Interface, error) MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
MockRemoveProvisioner func(ctx context.Context, id string) error MockRemoveProvisioner func(ctx context.Context, id string) error
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) (*linkedca.Policy, error)
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockRemoveAuthorityPolicy func(ctx context.Context) error
} }
func (m *mockAdminAuthority) IsAdminAPIEnabled() bool { func (m *mockAdminAuthority) IsAdminAPIEnabled() bool {
@ -130,6 +137,22 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e
return m.MockErr return m.MockErr
} }
func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
return nil, errors.New("not implemented yet")
}
func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
return nil, errors.New("not implemented yet")
}
func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
return nil, errors.New("not implemented yet")
}
func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error {
return errors.New("not implemented yet")
}
func TestCreateAdminRequest_Validate(t *testing.T) { func TestCreateAdminRequest_Validate(t *testing.T) {
type fields struct { type fields struct {
Subject string Subject string

View file

@ -8,24 +8,27 @@ import (
// Handler is the Admin API request handler. // Handler is the Admin API request handler.
type Handler struct { type Handler struct {
adminDB admin.DB adminDB admin.DB
auth adminAuthority auth adminAuthority
acmeDB acme.DB acmeDB acme.DB
acmeResponder acmeAdminResponderInterface acmeResponder acmeAdminResponderInterface
policyResponder policyAdminResponderInterface
} }
// NewHandler returns a new Authority Config Handler. // NewHandler returns a new Authority Config Handler.
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler { func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface, policyResponder policyAdminResponderInterface) api.RouterHandler {
return &Handler{ return &Handler{
auth: auth, auth: auth,
adminDB: adminDB, adminDB: adminDB,
acmeDB: acmeDB, acmeDB: acmeDB,
acmeResponder: acmeResponder, acmeResponder: acmeResponder,
policyResponder: policyResponder,
} }
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) { func (h *Handler) Route(r api.Router) {
authnz := func(next nextHTTP) nextHTTP { authnz := func(next nextHTTP) nextHTTP {
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next))
} }
@ -34,6 +37,14 @@ func (h *Handler) Route(r api.Router) {
return h.requireEABEnabled(next) return h.requireEABEnabled(next)
} }
enabledInStandalone := func(next nextHTTP) nextHTTP {
return h.checkAction(next, true)
}
disabledInStandalone := func(next nextHTTP) nextHTTP {
return h.checkAction(next, false)
}
// Provisioners // Provisioners
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner))
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners))
@ -53,4 +64,24 @@ func (h *Handler) Route(r api.Router) {
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys)))
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey))) r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey)))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey))) r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey)))
// Policy - Authority
r.MethodFunc("GET", "/policy", authnz(enabledInStandalone(h.policyResponder.GetAuthorityPolicy)))
r.MethodFunc("POST", "/policy", authnz(enabledInStandalone(h.policyResponder.CreateAuthorityPolicy)))
r.MethodFunc("PUT", "/policy", authnz(enabledInStandalone(h.policyResponder.UpdateAuthorityPolicy)))
r.MethodFunc("DELETE", "/policy", authnz(enabledInStandalone(h.policyResponder.DeleteAuthorityPolicy)))
// Policy - Provisioner
//r.MethodFunc("GET", "/provisioners/{name}/policy", noauth(h.policyResponder.GetProvisionerPolicy))
r.MethodFunc("GET", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.GetProvisionerPolicy)))
r.MethodFunc("POST", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.CreateProvisionerPolicy)))
r.MethodFunc("PUT", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.UpdateProvisionerPolicy)))
r.MethodFunc("DELETE", "/provisioners/{name}/policy", authnz(disabledInStandalone(h.policyResponder.DeleteProvisionerPolicy)))
// Policy - ACME Account
// TODO: ensure we don't clash with eab; might want to change eab paths slightly (as long as we don't have it released completely; needs changes in adminClient too)
r.MethodFunc("GET", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.GetACMEAccountPolicy)))
r.MethodFunc("POST", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.CreateACMEAccountPolicy)))
r.MethodFunc("PUT", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.UpdateACMEAccountPolicy)))
r.MethodFunc("DELETE", "/acme/{provisionerName}/{accountID}/policy", authnz(disabledInStandalone(h.policyResponder.DeleteACMEAccountPolicy)))
} }

View file

@ -1,11 +1,13 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/admin/db/nosql"
) )
type nextHTTP = func(http.ResponseWriter, *http.Request) type nextHTTP = func(http.ResponseWriter, *http.Request)
@ -26,6 +28,7 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization") tok := r.Header.Get("Authorization")
if tok == "" { if tok == "" {
api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType, api.WriteError(w, admin.NewError(admin.ErrorUnauthorizedType,
@ -39,16 +42,30 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return return
} }
ctx := context.WithValue(r.Context(), adminContextKey, adm) ctx := linkedca.WithAdmin(r.Context(), adm)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
// ContextKey is the key type for storing and searching for ACME request // checkAction checks if an action is supported in standalone or not
// essentials in the context of a request. func (h *Handler) checkAction(next nextHTTP, supportedInStandalone bool) nextHTTP {
type ContextKey string return func(w http.ResponseWriter, r *http.Request) {
const ( // actions allowed in standalone mode are always supported
// adminContextKey account key if supportedInStandalone {
adminContextKey = ContextKey("admin") next(w, r)
) return
}
// when not in standalone mode and using a nosql.DB backend,
// actions are not supported
if _, ok := h.adminDB.(*nosql.DB); ok {
api.WriteError(w, admin.NewError(admin.ErrorNotImplementedType,
"operation not supported in standalone mode"))
return
}
// continue to next http handler
next(w, r)
}
}

View file

@ -152,7 +152,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
req.Header["Authorization"] = []string{"token"} req.Header["Authorization"] = []string{"token"}
createdAt := time.Now() createdAt := time.Now()
var deletedAt time.Time var deletedAt time.Time
admin := &linkedca.Admin{ adm := &linkedca.Admin{
Id: "adminID", Id: "adminID",
AuthorityId: "authorityID", AuthorityId: "authorityID",
Subject: "admin", Subject: "admin",
@ -164,20 +164,15 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
auth := &mockAdminAuthority{ auth := &mockAdminAuthority{
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
assert.Equals(t, "token", token) assert.Equals(t, "token", token)
return admin, nil return adm, nil
}, },
} }
next := func(w http.ResponseWriter, r *http.Request) { next := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin adm := linkedca.AdminFromContext(ctx) // verifying that the context now has a linkedca.Admin
adm, ok := a.(*linkedca.Admin)
if !ok {
t.Errorf("expected *linkedca.Admin; got %T", a)
return
}
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
if !cmp.Equal(admin, adm, opts...) { if !cmp.Equal(adm, adm, opts...) {
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...)) t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...))
} }
w.Write(nil) // mock response with status 200 w.Write(nil) // mock response with status 200
} }

View file

@ -0,0 +1,326 @@
package api
import (
"net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
)
type policyAdminResponderInterface interface {
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request)
GetProvisionerPolicy(w http.ResponseWriter, r *http.Request)
CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request)
GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
}
// PolicyAdminResponder is responsible for writing ACME admin responses
type PolicyAdminResponder struct {
auth adminAuthority
adminDB admin.DB
}
// NewACMEAdminResponder returns a new ACMEAdminResponder
func NewPolicyAdminResponder(auth adminAuthority, adminDB admin.DB) *PolicyAdminResponder {
return &PolicyAdminResponder{
auth: auth,
adminDB: adminDB,
}
}
// GetAuthorityPolicy handles the GET /admin/authority/policy request
func (par *PolicyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
policy, err := par.auth.GetAuthorityPolicy(r.Context())
if ae, ok := err.(*admin.Error); ok {
if !ae.IsType(admin.ErrorNotFoundType) {
api.WriteError(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return
}
}
if policy == nil {
api.JSONNotFound(w)
return
}
api.ProtoJSONStatus(w, policy, http.StatusOK)
}
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
func (par *PolicyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
policy, err := par.auth.GetAuthorityPolicy(ctx)
shouldWriteError := false
if ae, ok := err.(*admin.Error); ok {
shouldWriteError = !ae.IsType(admin.ErrorNotFoundType)
}
if shouldWriteError {
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
return
}
if policy != nil {
adminErr := admin.NewError(admin.ErrorBadRequestType, "authority already has a policy")
adminErr.Status = http.StatusConflict
api.WriteError(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if !api.ReadProtoJSONWithCheck(w, r.Body, newPolicy) {
return
}
adm := linkedca.AdminFromContext(ctx)
var createdPolicy *linkedca.Policy
if createdPolicy, err = par.auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
return
}
api.JSONStatus(w, createdPolicy, http.StatusCreated)
}
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
func (par *PolicyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
policy, err := par.auth.GetAuthorityPolicy(ctx)
shouldWriteError := false
if ae, ok := err.(*admin.Error); ok {
shouldWriteError = !ae.IsType(admin.ErrorNotFoundType)
}
if shouldWriteError {
api.WriteError(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
return
}
if policy == nil {
api.JSONNotFound(w)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
api.WriteError(w, err)
return
}
adm := linkedca.AdminFromContext(ctx)
var updatedPolicy *linkedca.Policy
if updatedPolicy, err = par.auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
return
}
api.ProtoJSONStatus(w, updatedPolicy, http.StatusOK)
}
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
func (par *PolicyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
policy, err := par.auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok {
if !ae.IsType(admin.ErrorNotFoundType) {
api.WriteError(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return
}
}
if policy == nil {
api.JSONNotFound(w)
return
}
err = par.auth.RemoveAuthorityPolicy(ctx)
if err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error deleting authority policy"))
return
}
api.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
// TODO: move getting provisioner to middleware?
ctx := r.Context()
name := chi.URLParam(r, "name")
var (
p provisioner.Interface
err error
)
if p, err = par.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
prov, err := par.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
api.WriteError(w, err)
return
}
policy := prov.GetPolicy()
if policy == nil {
api.JSONNotFound(w)
return
}
api.ProtoJSONStatus(w, policy, http.StatusOK)
}
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
name := chi.URLParam(r, "name")
var (
p provisioner.Interface
err error
)
if p, err = par.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
prov, err := par.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
api.WriteError(w, err)
return
}
policy := prov.GetPolicy()
if policy != nil {
adminErr := admin.NewError(admin.ErrorBadRequestType, "provisioner %s already has a policy", name)
adminErr.Status = http.StatusConflict
api.WriteError(w, adminErr)
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
api.WriteError(w, err)
return
}
prov.Policy = newPolicy
err = par.auth.UpdateProvisioner(ctx, prov)
if err != nil {
api.WriteError(w, err)
return
}
api.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
name := chi.URLParam(r, "name")
var (
p provisioner.Interface
err error
)
if p, err = par.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
prov, err := par.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
api.WriteError(w, err)
return
}
var policy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, policy); err != nil {
api.WriteError(w, err)
return
}
prov.Policy = policy
err = par.auth.UpdateProvisioner(ctx, prov)
if err != nil {
api.WriteError(w, err)
return
}
api.ProtoJSONStatus(w, policy, http.StatusOK)
}
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
func (par *PolicyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
name := chi.URLParam(r, "name")
var (
p provisioner.Interface
err error
)
if p, err = par.auth.LoadProvisionerByName(name); err != nil {
api.WriteError(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
prov, err := par.adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
api.WriteError(w, err)
return
}
if prov.Policy == nil {
api.JSONNotFound(w)
return
}
// remove the policy
prov.Policy = nil
err = par.auth.UpdateProvisioner(ctx, prov)
if err != nil {
api.WriteError(w, err)
return
}
api.JSON(w, &DeleteResponse{Status: "ok"})
}
func (par *PolicyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
api.JSON(w, "ok")
}
func (par *PolicyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
api.JSON(w, "ok")
}
func (par *PolicyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
api.JSON(w, "ok")
}
func (par *PolicyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
api.JSON(w, "ok")
}

View file

@ -4,12 +4,14 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/linkedca"
) )
// GetProvisionersResponse is the type for GET /admin/provisioners responses. // GetProvisionersResponse is the type for GET /admin/provisioners responses.
@ -72,7 +74,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
// CreateProvisioner creates a new prov. // CreateProvisioner creates a new prov.
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner) var prov = new(linkedca.Provisioner)
if err := api.ReadProtoJSON(r.Body, prov); err != nil { if err := read.ProtoJSON(r.Body, prov); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
@ -122,7 +124,7 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
// UpdateProvisioner updates an existing prov. // UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner) var nu = new(linkedca.Provisioner)
if err := api.ReadProtoJSON(r.Body, nu); err != nil { if err := read.ProtoJSON(r.Body, nu); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }

View file

@ -0,0 +1,10 @@
package admin
// ContextKey is the key type for storing and searching for
// Admin API objects in request contexts.
type ContextKey string
const (
// AdminContextKey account key
AdminContextKey = ContextKey("admin")
)

View file

@ -69,6 +69,11 @@ type DB interface {
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
DeleteAdmin(ctx context.Context, id string) error DeleteAdmin(ctx context.Context, id string) error
CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
DeleteAuthorityPolicy(ctx context.Context) error
} }
// MockDB is an implementation of the DB interface that should only be used as // MockDB is an implementation of the DB interface that should only be used as
@ -86,6 +91,11 @@ type MockDB struct {
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
MockDeleteAdmin func(ctx context.Context, id string) error MockDeleteAdmin func(ctx context.Context, id string) error
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockDeleteAuthorityPolicy func(ctx context.Context) error
MockError error MockError error
MockRet1 interface{} MockRet1 interface{}
} }
@ -179,3 +189,30 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
} }
return m.MockError return m.MockError
} }
func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
if m.MockCreateAuthorityPolicy != nil {
return m.MockCreateAuthorityPolicy(ctx, policy)
}
return m.MockError
}
func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
if m.MockGetAuthorityPolicy != nil {
return m.MockGetAuthorityPolicy(ctx)
}
return m.MockRet1.(*linkedca.Policy), m.MockError
}
func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
if m.MockUpdateAuthorityPolicy != nil {
return m.MockUpdateAuthorityPolicy(ctx, policy)
}
return m.MockError
}
func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error {
if m.MockDeleteAuthorityPolicy != nil {
return m.MockDeleteAuthorityPolicy(ctx)
}
return m.MockError
}

View file

@ -11,8 +11,9 @@ import (
) )
var ( var (
adminsTable = []byte("admins") adminsTable = []byte("admins")
provisionersTable = []byte("provisioners") provisionersTable = []byte("provisioners")
authorityPoliciesTable = []byte("authority_policies")
) )
// DB is a struct that implements the AdminDB interface. // DB is a struct that implements the AdminDB interface.
@ -23,7 +24,7 @@ type DB struct {
// New configures and returns a new Authority DB backend implemented using a nosql DB. // New configures and returns a new Authority DB backend implemented using a nosql DB.
func New(db nosqlDB.DB, authorityID string) (*DB, error) { func New(db nosqlDB.DB, authorityID string) (*DB, error) {
tables := [][]byte{adminsTable, provisionersTable} tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable}
for _, b := range tables { for _, b := range tables {
if err := db.CreateTable(b); err != nil { if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s", return nil, errors.Wrapf(err, "error creating table %s",

View file

@ -0,0 +1,144 @@
package nosql
import (
"context"
"encoding/json"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/nosql"
"go.step.sm/linkedca"
)
type dbAuthorityPolicy struct {
ID string `json:"id"`
AuthorityID string `json:"authorityID"`
Policy *linkedca.Policy `json:"policy"`
}
func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy {
return dbap.Policy
}
func (dbap *dbAuthorityPolicy) clone() *dbAuthorityPolicy {
u := *dbap
return &u
}
func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) {
data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID))
if nosql.IsErrNotFound(err) {
return nil, admin.NewError(admin.ErrorNotFoundType, "policy %s not found", authorityID)
} else if err != nil {
return nil, errors.Wrapf(err, "error loading admin %s", authorityID)
}
return data, nil
}
func (db *DB) unmarshalDBAuthorityPolicy(data []byte, authorityID string) (*dbAuthorityPolicy, error) {
var dba = new(dbAuthorityPolicy)
if err := json.Unmarshal(data, dba); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling admin %s into dbAdmin", authorityID)
}
// if !dba.DeletedAt.IsZero() {
// return nil, admin.NewError(admin.ErrorDeletedType, "admin %s is deleted", authorityID)
// }
if dba.AuthorityID != db.authorityID {
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
"admin %s is not owned by authority %s", dba.ID, db.authorityID)
}
return dba, nil
}
func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) {
data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID)
if err != nil {
return nil, err
}
dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID)
if err != nil {
return nil, err
}
return dbap, nil
}
// func (db *DB) unmarshalAuthorityPolicy(data []byte, authorityID string) (*linkedca.Policy, error) {
// dbap, err := db.unmarshalDBAuthorityPolicy(data, authorityID)
// if err != nil {
// return nil, err
// }
// return dbap.convert(), nil
// }
func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
dbap := &dbAuthorityPolicy{
ID: db.authorityID,
AuthorityID: db.authorityID,
Policy: policy,
}
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable)
}
func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
// policy := &linkedca.Policy{
// X509: &linkedca.X509Policy{
// Allow: &linkedca.X509Names{
// Dns: []string{".localhost"},
// },
// Deny: &linkedca.X509Names{
// Dns: []string{"denied.localhost"},
// },
// },
// Ssh: &linkedca.SSHPolicy{
// User: &linkedca.SSHUserPolicy{
// Allow: &linkedca.SSHUserNames{},
// Deny: &linkedca.SSHUserNames{},
// },
// Host: &linkedca.SSHHostPolicy{
// Allow: &linkedca.SSHHostNames{},
// Deny: &linkedca.SSHHostNames{},
// },
// },
// }
dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return nil, err
}
return dbap.convert(), nil
}
func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
dbap := &dbAuthorityPolicy{
ID: db.authorityID,
AuthorityID: db.authorityID,
Policy: policy,
}
return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable)
}
func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error {
dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
old := dbap.clone()
dbap.Policy = nil
return db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable)
}

View file

@ -19,6 +19,7 @@ type dbProvisioner struct {
Type linkedca.Provisioner_Type `json:"type"` Type linkedca.Provisioner_Type `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Claims *linkedca.Claims `json:"claims"` Claims *linkedca.Claims `json:"claims"`
Policy *linkedca.Policy `json:"policy"`
Details []byte `json:"details"` Details []byte `json:"details"`
X509Template *linkedca.Template `json:"x509Template"` X509Template *linkedca.Template `json:"x509Template"`
SSHTemplate *linkedca.Template `json:"sshTemplate"` SSHTemplate *linkedca.Template `json:"sshTemplate"`
@ -43,6 +44,7 @@ func (dbp *dbProvisioner) convert2linkedca() (*linkedca.Provisioner, error) {
Type: dbp.Type, Type: dbp.Type,
Name: dbp.Name, Name: dbp.Name,
Claims: dbp.Claims, Claims: dbp.Claims,
Policy: dbp.Policy,
Details: details, Details: details,
X509Template: dbp.X509Template, X509Template: dbp.X509Template,
SshTemplate: dbp.SSHTemplate, SshTemplate: dbp.SSHTemplate,
@ -160,6 +162,7 @@ func (db *DB) CreateProvisioner(ctx context.Context, prov *linkedca.Provisioner)
Type: prov.Type, Type: prov.Type,
Name: prov.Name, Name: prov.Name,
Claims: prov.Claims, Claims: prov.Claims,
Policy: prov.Policy,
Details: details, Details: details,
X509Template: prov.X509Template, X509Template: prov.X509Template,
SSHTemplate: prov.SshTemplate, SSHTemplate: prov.SshTemplate,
@ -187,6 +190,7 @@ func (db *DB) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner)
} }
nu.Name = prov.Name nu.Name = prov.Name
nu.Claims = prov.Claims nu.Claims = prov.Claims
nu.Policy = prov.Policy
nu.Details, err = json.Marshal(prov.Details.GetData()) nu.Details, err = json.Marshal(prov.Details.GetData())
if err != nil { if err != nil {
return admin.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name) return admin.WrapErrorISE(err, "error marshaling details when updating provisioner %s", prov.Name)

View file

@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool
} }
// Store adds an admin to the collection and enforces the uniqueness of // Store adds an admin to the collection and enforces the uniqueness of
// admin IDs and amdin subject <-> provisioner name combos. // admin IDs and admin subject <-> provisioner name combos.
func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error { func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error {
// Input validation. // Input validation.
if adm.ProvisionerId != prov.GetID() { if adm.ProvisionerId != prov.GetID() {

View file

@ -16,6 +16,7 @@ import (
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql" adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/administrator"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas" "github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1" casapi "github.com/smallstep/certificates/cas/apiv1"
@ -70,10 +71,17 @@ type Authority struct {
startTime time.Time startTime time.Time
// Custom functions // Custom functions
sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error)
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
getIdentityFunc provisioner.GetIdentityFunc getIdentityFunc provisioner.GetIdentityFunc
authorizeRenewFunc provisioner.AuthorizeRenewFunc
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
// Policy engines
x509Policy policy.X509Policy
sshUserPolicy policy.UserPolicy
sshHostPolicy policy.HostPolicy
adminMutex sync.RWMutex adminMutex sync.RWMutex
} }
@ -199,6 +207,51 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
a.provisioners = provClxn a.provisioners = provClxn
a.config.AuthorityConfig.Admins = adminList a.config.AuthorityConfig.Admins = adminList
a.admins = adminClxn a.admins = adminClxn
return nil
}
// reloadPolicyEngines reloads x509 and SSH policy engines using
// configuration stored in the DB or from the configuration file.
func (a *Authority) reloadPolicyEngines(ctx context.Context) error {
var (
err error
policyOptions *policy.Options
)
if a.config.AuthorityConfig.EnableAdmin {
linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx)
if err != nil {
return admin.WrapErrorISE(err, "error getting policy to (re)load policy engines")
}
policyOptions = policyToCertificates(linkedPolicy)
} else {
policyOptions = a.config.AuthorityConfig.Policy
}
// if no new or updated policy option is set, clear policy engines that (may have)
// been configured before and return early
if policyOptions == nil {
a.x509Policy = nil
a.sshHostPolicy = nil
a.sshUserPolicy = nil
return nil
}
// Initialize the x509 allow/deny policy engine
if a.x509Policy, err = policy.NewX509PolicyEngine(policyOptions.GetX509Options()); err != nil {
return err
}
// // Initialize the SSH allow/deny policy engine for host certificates
if a.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(policyOptions.GetSSHOptions()); err != nil {
return err
}
// // Initialize the SSH allow/deny policy engine for user certificates
if a.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(policyOptions.GetSSHOptions()); err != nil {
return err
}
return nil return nil
} }
@ -527,6 +580,11 @@ func (a *Authority) init() error {
return err return err
} }
// Load x509 and SSH Policy Engines
if err := a.reloadPolicyEngines(context.Background()); err != nil {
return err
}
// Configure templates, currently only ssh templates are supported. // Configure templates, currently only ssh templates are supported.
if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil { if a.sshCAHostCertSignKey != nil || a.sshCAUserCertSignKey != nil {
a.templates = a.config.Templates a.templates = a.config.Templates

View file

@ -6,6 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -276,6 +277,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
func (a *Authority) authorizeRenew(cert *x509.Certificate) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
serial := cert.SerialNumber.String() serial := cert.SerialNumber.String()
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
isRevoked, err := a.IsRevoked(serial) isRevoked, err := a.IsRevoked(serial)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
@ -283,7 +285,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
if isRevoked { if isRevoked {
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
} }
p, ok := a.provisioners.LoadByCertificate(cert) p, ok := a.provisioners.LoadByCertificate(cert)
if !ok { if !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
} }
return nil return nil
} }
// AuthorizeRenewToken validates the renew token and returns the leaf
// certificate in the x5cInsecure header.
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
var claims jose.Claims
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
if err != nil {
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
}
leaf := chain[0][0]
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
}
p, ok := a.provisioners.LoadByCertificate(leaf)
if !ok {
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
}
if err := a.UseToken(ott, p); err != nil {
return nil, err
}
if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: p.GetName(),
Subject: leaf.Subject.CommonName,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
switch err {
case jose.ErrInvalidIssuer:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)"))
case jose.ErrInvalidSubject:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)"))
case jose.ErrNotValidYet:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)"))
case jose.ErrExpired:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)"))
case jose.ErrIssuedInTheFuture:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)"))
default:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
}
}
audiences := a.config.GetAudiences().Renew
if !matchesAudience(claims.Audience, audiences) {
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
}
return leaf, nil
}
// matchesAudience returns true if A and B share at least one element.
func matchesAudience(as, bs []string) bool {
if len(bs) == 0 || len(as) == 0 {
return false
}
for _, b := range bs {
for _, a := range as {
if b == a || stripPort(a) == stripPort(b) {
return true
}
}
}
return false
}
// stripPort attempts to strip the port from the given url. If parsing the url
// produces errors it will just return the passed argument.
func stripPort(rawurl string) string {
u, err := url.Parse(rawurl)
if err != nil {
return rawurl
}
u.Host = u.Hostname()
return u.String()
}

View file

@ -3,11 +3,15 @@ package authority
import ( import (
"context" "context"
"crypto" "crypto"
"crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"strconv" "strconv"
"testing" "testing"
"time" "time"
@ -20,6 +24,7 @@ import (
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -753,6 +758,7 @@ func TestAuthority_Authorize(t *testing.T) {
func TestAuthority_authorizeRenew(t *testing.T) { func TestAuthority_authorizeRenew(t *testing.T) {
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
fooCrt.NotAfter = time.Now().Add(time.Hour)
assert.FatalError(t, err) assert.FatalError(t, err)
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
@ -822,7 +828,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
cert: renewDisabledCrt, cert: renewDisabledCrt,
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"), err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -909,6 +915,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.
} }
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
now := time.Now()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -917,6 +924,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now.Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
}
if err := cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1003,6 +1016,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
} }
func TestAuthority_authorizeSSHRenew(t *testing.T) { func TestAuthority_authorizeSSHRenew(t *testing.T) {
now := time.Now().UTC()
sshpop := func(a *Authority) (*ssh.Certificate, string) {
p, ok := a.provisioners.Load("sshpop/sshpop")
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return cert, token
}
a := testAuthority(t) a := testAuthority(t)
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
@ -1012,8 +1042,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err) assert.FatalError(t, err)
now := time.Now().UTC()
validIssuer := "step-cli" validIssuer := "step-cli"
type authorizeTest struct { type authorizeTest struct {
@ -1050,27 +1078,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
"fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
return errs.Forbidden("forbidden")
}))
_, token := sshpop(aa)
return &authorizeTest{
auth: aa,
token: token,
err: errors.New("authority.authorizeSSHRenew: forbidden"),
code: http.StatusForbidden,
}
},
"ok": func(t *testing.T) *authorizeTest { "ok": func(t *testing.T) *authorizeTest {
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") cert, token := sshpop(a)
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
p, ok := a.provisioners.Load("sshpop/sshpop")
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop",
[]string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: tok, token: token,
cert: cert,
}
},
"ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
return nil
}))
cert, token := sshpop(aa)
return &authorizeTest{
auth: aa,
token: token,
cert: cert, cert: cert,
} }
}, },
@ -1290,3 +1325,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
}) })
} }
} }
func TestAuthority_AuthorizeRenewToken(t *testing.T) {
ctx := context.Background()
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
_, signer, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer)
if err != nil {
t.Fatal(err)
}
_, otherSigner, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
if err != nil {
t.Fatal(err)
}
var x5c []string
for _, c := range chain {
x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw))
}
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("x5cInsecure", x5c)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so)
if err != nil {
t.Fatal(err)
}
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return s, chain[0]
}
now := time.Now()
a1 := testAuthority(t)
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now.Add(-time.Hour)
cert.NotAfter = now.Add(-time.Minute)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "bad-issuer",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "bad-subject",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
type args struct {
ctx context.Context
ott string
}
tests := []struct {
name string
authority *Authority
args args
want *x509.Certificate
wantErr bool
}{
{"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
{"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true},
{"fail token iss", a1, args{ctx, badIssuer}, nil, true},
{"fail token sub", a1, args{ctx, badSubject}, nil, true},
{"fail token iat", a1, args{ctx, badNotBefore}, nil, true},
{"fail token iat", a1, args{ctx, badExpiry}, nil, true},
{"fail token iat", a1, args{ctx, badIssuedAt}, nil, true},
{"fail token aud", a1, args{ctx, badAudience}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
cas "github.com/smallstep/certificates/cas/apiv1" cas "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
@ -26,23 +27,27 @@ var (
DefaultBackdate = time.Minute DefaultBackdate = time.Minute
// DefaultDisableRenewal disables renewals per provisioner. // DefaultDisableRenewal disables renewals per provisioner.
DefaultDisableRenewal = false DefaultDisableRenewal = false
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
// expired.
DefaultAllowRenewAfterExpiry = false
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally // DefaultEnableSSHCA enable SSH CA features per provisioner or globally
// for all provisioners. // for all provisioners.
DefaultEnableSSHCA = false DefaultEnableSSHCA = false
// GlobalProvisionerClaims default claims for the Authority. Can be overridden // GlobalProvisionerClaims default claims for the Authority. Can be overridden
// by provisioner specific claims. // by provisioner specific claims.
GlobalProvisionerClaims = provisioner.Claims{ GlobalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &DefaultDisableRenewal, MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, EnableSSHCA: &DefaultEnableSSHCA,
EnableSSHCA: &DefaultEnableSSHCA, DisableRenewal: &DefaultDisableRenewal,
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
} }
) )
@ -90,6 +95,7 @@ type AuthConfig struct {
Admins []*linkedca.Admin `json:"-"` Admins []*linkedca.Admin `json:"-"`
Template *ASN1DN `json:"template,omitempty"` Template *ASN1DN `json:"template,omitempty"`
Claims *provisioner.Claims `json:"claims,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"`
Policy *policy.Options `json:"policy,omitempty"`
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
Backdate *provisioner.Duration `json:"backdate,omitempty"` Backdate *provisioner.Duration `json:"backdate,omitempty"`
EnableAdmin bool `json:"enableAdmin,omitempty"` EnableAdmin bool `json:"enableAdmin,omitempty"`
@ -269,28 +275,32 @@ func (c *Config) GetAudiences() provisioner.Audiences {
} }
for _, name := range c.DNSNames { for _, name := range c.DNSNames {
hostname := toHostname(name)
audiences.Sign = append(audiences.Sign, audiences.Sign = append(audiences.Sign,
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/sign", hostname),
fmt.Sprintf("https://%s/sign", toHostname(name)), fmt.Sprintf("https://%s/sign", hostname),
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
fmt.Sprintf("https://%s/ssh/sign", toHostname(name))) fmt.Sprintf("https://%s/ssh/sign", hostname))
audiences.Renew = append(audiences.Renew,
fmt.Sprintf("https://%s/1.0/renew", hostname),
fmt.Sprintf("https://%s/renew", hostname))
audiences.Revoke = append(audiences.Revoke, audiences.Revoke = append(audiences.Revoke,
fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)), fmt.Sprintf("https://%s/1.0/revoke", hostname),
fmt.Sprintf("https://%s/revoke", toHostname(name))) fmt.Sprintf("https://%s/revoke", hostname))
audiences.SSHSign = append(audiences.SSHSign, audiences.SSHSign = append(audiences.SSHSign,
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/ssh/sign", hostname),
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/sign", hostname),
fmt.Sprintf("https://%s/sign", toHostname(name))) fmt.Sprintf("https://%s/sign", hostname))
audiences.SSHRevoke = append(audiences.SSHRevoke, audiences.SSHRevoke = append(audiences.SSHRevoke,
fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname),
fmt.Sprintf("https://%s/ssh/revoke", toHostname(name))) fmt.Sprintf("https://%s/ssh/revoke", hostname))
audiences.SSHRenew = append(audiences.SSHRenew, audiences.SSHRenew = append(audiences.SSHRenew,
fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/renew", hostname),
fmt.Sprintf("https://%s/ssh/renew", toHostname(name))) fmt.Sprintf("https://%s/ssh/renew", hostname))
audiences.SSHRekey = append(audiences.SSHRekey, audiences.SSHRekey = append(audiences.SSHRekey,
fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname),
fmt.Sprintf("https://%s/ssh/rekey", toHostname(name))) fmt.Sprintf("https://%s/ssh/rekey", hostname))
} }
return audiences return audiences

View file

@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
@ -34,6 +35,9 @@ type linkedCaClient struct {
authorityID string authorityID string
} }
// interface guard
var _ admin.DB = (*linkedCaClient)(nil)
type linkedCAClaims struct { type linkedCAClaims struct {
jose.Claims jose.Claims
SANs []string `json:"sans"` SANs []string `json:"sans"`
@ -310,6 +314,22 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
} }
func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet")
}
func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
return nil, errors.New("not implemented yet")
}
func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet")
}
func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error {
return errors.New("not implemented yet")
}
func serializeCertificate(crt *x509.Certificate) string { func serializeCertificate(crt *x509.Certificate) string {
if crt == nil { if crt == nil {
return "" return ""

View file

@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e
} }
} }
// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of
// an X.509 certificate.
func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option {
return func(a *Authority) error {
a.authorizeRenewFunc = fn
return nil
}
}
// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal
// of a SSH certificate.
func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option {
return func(a *Authority) error {
a.authorizeSSHRenewFunc = fn
return nil
}
}
// WithSSHBastionFunc sets a custom function to get the bastion for a // WithSSHBastionFunc sets a custom function to get the bastion for a
// given user-host pair. // given user-host pair.
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option {

195
authority/policy.go Normal file
View file

@ -0,0 +1,195 @@
package authority
import (
"context"
"fmt"
"github.com/pkg/errors"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
authPolicy "github.com/smallstep/certificates/authority/policy"
policy "github.com/smallstep/certificates/policy"
)
func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
p, err := a.adminDB.GetAuthorityPolicy(ctx)
if err != nil {
return nil, err
}
return p, nil
}
func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.checkPolicy(ctx, adm, p); err != nil {
return nil, err
}
if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil {
return nil, err
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return nil, admin.WrapErrorISE(err, "error reloading policy engines when creating authority policy")
}
return p, nil // TODO: return the newly stored policy
}
func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.checkPolicy(ctx, adm, p); err != nil {
return nil, err
}
if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil {
return nil, err
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return nil, admin.WrapErrorISE(err, "error reloading policy engines when updating authority policy")
}
return p, nil // TODO: return the updated stored policy
}
func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil {
return err
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return admin.WrapErrorISE(err, "error reloading policy engines when deleting authority policy")
}
return nil
}
func (a *Authority) checkPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) error {
// convert the policy; return early if nil
policyOptions := policyToCertificates(p)
if policyOptions == nil {
return nil
}
engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options())
if err != nil {
return admin.WrapErrorISE(err, "error creating temporary policy engine")
}
// TODO(hs): Provide option to force the policy, even when the admin subject would be locked out?
sans := []string{adm.Subject}
if err := isAllowed(engine, sans); err != nil {
return err
}
// TODO(hs): perform the check for other admin subjects too?
// What logic to use for that: do all admins need access? Only super admins? At least one?
return nil
}
func isAllowed(engine authPolicy.X509Policy, sans []string) error {
var (
allowed bool
err error
)
if allowed, err = engine.AreSANsAllowed(sans); err != nil {
var policyErr *policy.NamePolicyError
if isPolicyErr := errors.As(err, &policyErr); isPolicyErr && policyErr.Reason == policy.NotAuthorizedForThisName {
return fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans)
}
return err
}
if !allowed {
return fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans)
}
return nil
}
func policyToCertificates(p *linkedca.Policy) *authPolicy.Options {
// return early
if p == nil {
return nil
}
// prepare full policy struct
opts := &authPolicy.Options{
X509: &authPolicy.X509PolicyOptions{
AllowedNames: &authPolicy.X509NameOptions{},
DeniedNames: &authPolicy.X509NameOptions{},
},
SSH: &authPolicy.SSHPolicyOptions{
Host: &authPolicy.SSHHostCertificateOptions{
AllowedNames: &authPolicy.SSHNameOptions{},
DeniedNames: &authPolicy.SSHNameOptions{},
},
User: &authPolicy.SSHUserCertificateOptions{
AllowedNames: &authPolicy.SSHNameOptions{},
DeniedNames: &authPolicy.SSHNameOptions{},
},
},
}
// fill x509 policy configuration
if p.X509 != nil {
if p.X509.Allow != nil {
opts.X509.AllowedNames.DNSDomains = p.X509.Allow.Dns
opts.X509.AllowedNames.IPRanges = p.X509.Allow.Ips
opts.X509.AllowedNames.EmailAddresses = p.X509.Allow.Emails
opts.X509.AllowedNames.URIDomains = p.X509.Allow.Uris
}
if p.X509.Deny != nil {
opts.X509.DeniedNames.DNSDomains = p.X509.Deny.Dns
opts.X509.DeniedNames.IPRanges = p.X509.Deny.Ips
opts.X509.DeniedNames.EmailAddresses = p.X509.Deny.Emails
opts.X509.DeniedNames.URIDomains = p.X509.Deny.Uris
}
}
// fill ssh policy configuration
if p.Ssh != nil {
if p.Ssh.Host != nil {
if p.Ssh.Host.Allow != nil {
opts.SSH.Host.AllowedNames.DNSDomains = p.Ssh.Host.Allow.Dns
opts.SSH.Host.AllowedNames.IPRanges = p.Ssh.Host.Allow.Ips
opts.SSH.Host.AllowedNames.EmailAddresses = p.Ssh.Host.Allow.Principals
}
if p.Ssh.Host.Deny != nil {
opts.SSH.Host.DeniedNames.DNSDomains = p.Ssh.Host.Deny.Dns
opts.SSH.Host.DeniedNames.IPRanges = p.Ssh.Host.Deny.Ips
opts.SSH.Host.DeniedNames.Principals = p.Ssh.Host.Deny.Principals
}
}
if p.Ssh.User != nil {
if p.Ssh.User.Allow != nil {
opts.SSH.User.AllowedNames.EmailAddresses = p.Ssh.User.Allow.Emails
opts.SSH.User.AllowedNames.Principals = p.Ssh.User.Allow.Principals
}
if p.Ssh.User.Deny != nil {
opts.SSH.User.DeniedNames.EmailAddresses = p.Ssh.User.Deny.Emails
opts.SSH.User.DeniedNames.Principals = p.Ssh.User.Deny.Principals
}
}
}
return opts
}

188
authority/policy/options.go Normal file
View file

@ -0,0 +1,188 @@
package policy
// Options is a container for authority level x509 and SSH
// policy configuration.
type Options struct {
X509 *X509PolicyOptions `json:"x509,omitempty"`
SSH *SSHPolicyOptions `json:"ssh,omitempty"`
}
// GetX509Options returns the x509 authority level policy
// configuration
func (o *Options) GetX509Options() *X509PolicyOptions {
if o == nil {
return nil
}
return o.X509
}
// GetSSHOptions returns the SSH authority level policy
// configuration
func (o *Options) GetSSHOptions() *SSHPolicyOptions {
if o == nil {
return nil
}
return o.SSH
}
// X509PolicyOptionsInterface is an interface for providers
// of x509 allowed and denied names.
type X509PolicyOptionsInterface interface {
GetAllowedNameOptions() *X509NameOptions
GetDeniedNameOptions() *X509NameOptions
}
// X509PolicyOptions is a container for x509 allowed and denied
// names.
type X509PolicyOptions struct {
// AllowedNames contains the x509 allowed names
AllowedNames *X509NameOptions `json:"allow,omitempty"`
// DeniedNames contains the x509 denied names
DeniedNames *X509NameOptions `json:"deny,omitempty"`
}
// X509NameOptions models the X509 name policy configuration.
type X509NameOptions struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
URIDomains []string `json:"uri,omitempty"`
}
// HasNames checks if the AllowedNameOptions has one or more
// names configured.
func (o *X509NameOptions) HasNames() bool {
return len(o.DNSDomains) > 0 ||
len(o.IPRanges) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.URIDomains) > 0
}
// SSHPolicyOptionsInterface is an interface for providers of
// SSH user and host name policy configuration.
type SSHPolicyOptionsInterface interface {
GetAllowedUserNameOptions() *SSHNameOptions
GetDeniedUserNameOptions() *SSHNameOptions
GetAllowedHostNameOptions() *SSHNameOptions
GetDeniedHostNameOptions() *SSHNameOptions
}
// SSHPolicyOptions is a container for SSH user and host policy
// configuration
type SSHPolicyOptions struct {
// User contains SSH user certificate options.
User *SSHUserCertificateOptions `json:"user,omitempty"`
// Host contains SSH host certificate options.
Host *SSHHostCertificateOptions `json:"host,omitempty"`
}
// GetAllowedNameOptions returns x509 allowed name policy configuration
func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the x509 denied name policy configuration
func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// GetAllowedUserNameOptions returns the SSH allowed user name policy
// configuration.
func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
if o.User == nil {
return nil
}
return o.User.AllowedNames
}
// GetDeniedUserNameOptions returns the SSH denied user name policy
// configuration.
func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
if o.User == nil {
return nil
}
return o.User.DeniedNames
}
// GetAllowedHostNameOptions returns the SSH allowed host name policy
// configuration.
func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
if o.Host == nil {
return nil
}
return o.Host.AllowedNames
}
// GetDeniedHostNameOptions returns the SSH denied host name policy
// configuration.
func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
if o.Host == nil {
return nil
}
return o.Host.DeniedNames
}
// SSHUserCertificateOptions is a collection of SSH user certificate options.
type SSHUserCertificateOptions struct {
// AllowedNames contains the names the provisioner is authorized to sign
AllowedNames *SSHNameOptions `json:"allow,omitempty"`
// DeniedNames contains the names the provisioner is not authorized to sign
DeniedNames *SSHNameOptions `json:"deny,omitempty"`
}
// SSHHostCertificateOptions is a collection of SSH host certificate options.
// It's an alias of SSHUserCertificateOptions, as the options are the same
// for both types of certificates.
type SSHHostCertificateOptions SSHUserCertificateOptions
// SSHNameOptions models the SSH name policy configuration.
type SSHNameOptions struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
Principals []string `json:"principal,omitempty"`
}
// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the
// names that a provisioner is authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the
// names that a provisioner is NOT authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// HasNames checks if the SSHNameOptions has one or more
// names configured.
func (o *SSHNameOptions) HasNames() bool {
return len(o.DNSDomains) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.Principals) > 0
}

134
authority/policy/policy.go Normal file
View file

@ -0,0 +1,134 @@
package policy
import (
"fmt"
"github.com/smallstep/certificates/policy"
)
// X509Policy is an alias for policy.X509NamePolicyEngine
type X509Policy policy.X509NamePolicyEngine
// UserPolicy is an alias for policy.SSHNamePolicyEngine
type UserPolicy policy.SSHNamePolicyEngine
// HostPolicy is an alias for policy.SSHNamePolicyEngine
type HostPolicy policy.SSHNamePolicyEngine
// NewX509PolicyEngine creates a new x509 name policy engine
func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) {
// return early if no policy engine options to configure
if policyOptions == nil {
return nil, nil
}
options := []policy.NamePolicyOption{}
allowed := policyOptions.GetAllowedNameOptions()
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedDNSDomains(allowed.DNSDomains),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses),
policy.WithPermittedURIDomains(allowed.URIDomains),
)
}
denied := policyOptions.GetDeniedNameOptions()
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedDNSDomains(denied.DNSDomains),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges),
policy.WithExcludedEmailAddresses(denied.EmailAddresses),
policy.WithExcludedURIDomains(denied.URIDomains),
)
}
// ensure no policy engine is returned when no name options were provided
if len(options) == 0 {
return nil, nil
}
// enable x509 Subject Common Name validation by default
options = append(options, policy.WithSubjectCommonNameVerification())
return policy.New(options...)
}
type sshPolicyEngineType string
const (
UserPolicyEngineType sshPolicyEngineType = "user"
HostPolicyEngineType sshPolicyEngineType = "host"
)
// newSSHUserPolicyEngine creates a new SSH user certificate policy engine
func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) {
policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType)
if err != nil {
return nil, err
}
return policyEngine, nil
}
// newSSHHostPolicyEngine create a new SSH host certificate policy engine
func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) {
policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType)
if err != nil {
return nil, err
}
return policyEngine, nil
}
// newSSHPolicyEngine creates a new SSH name policy engine
func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
// return early if no policy engine options to configure
if policyOptions == nil {
return nil, nil
}
var (
allowed *SSHNameOptions
denied *SSHNameOptions
)
switch typ {
case UserPolicyEngineType:
allowed = policyOptions.GetAllowedUserNameOptions()
denied = policyOptions.GetDeniedUserNameOptions()
case HostPolicyEngineType:
allowed = policyOptions.GetAllowedHostNameOptions()
denied = policyOptions.GetDeniedHostNameOptions()
default:
return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
}
options := []policy.NamePolicyOption{}
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedDNSDomains(allowed.DNSDomains),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses),
policy.WithPermittedPrincipals(allowed.Principals),
)
}
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedDNSDomains(denied.DNSDomains),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges),
policy.WithExcludedEmailAddresses(denied.EmailAddresses),
policy.WithExcludedPrincipals(denied.Principals),
)
}
// ensure no policy engine is returned when no name options were provided
if len(options) == 0 {
return nil, nil
}
return policy.New(options...)
}

View file

@ -7,8 +7,8 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/policy" "github.com/smallstep/certificates/authority/policy"
) )
// ACME is the acme provisioner type, an entity that can authorize the ACME // ACME is the acme provisioner type, an entity that can authorize the ACME
@ -26,8 +26,9 @@ type ACME struct {
RequireEAB bool `json:"requireEAB,omitempty"` RequireEAB bool `json:"requireEAB,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
x509Policy policy.X509NamePolicyEngine ctl *Controller
x509Policy policy.X509Policy
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -72,7 +73,7 @@ func (p *ACME) GetOptions() *Options {
// DefaultTLSCertDuration returns the default TLS cert duration enforced by // DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner. // the provisioner.
func (p *ACME) DefaultTLSCertDuration() time.Duration { func (p *ACME) DefaultTLSCertDuration() time.Duration {
return p.claimer.DefaultTLSCertDuration() return p.ctl.Claimer.DefaultTLSCertDuration()
} }
// Init initializes and validates the fields of an ACME type. // Init initializes and validates the fields of an ACME type.
@ -84,19 +85,13 @@ func (p *ACME) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
} }
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
// TODO(hs): ensure no race conditions happen when reloading settings and requesting certs? if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
// TODO(hs): implement memoization strategy, so that reloading is not required when no changes were made to allow/deny?
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// ACMEIdentifierType encodes ACME Identifier types // ACMEIdentifierType encodes ACME Identifier types
@ -115,20 +110,22 @@ type ACMEIdentifier struct {
Value string Value string
} }
// AuthorizeOrderIdentifiers verifies the provisioner is authorized to issue a // AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a
// certificate for the Identifiers provided in an Order. // certificate for an ACME Order Identifier.
func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier string) error { func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error {
// identifier is allowed if no policy is configured
if p.x509Policy == nil { if p.x509Policy == nil {
return nil return nil
} }
// assuming only valid identifiers (IP or DNS) are provided // assuming only valid identifiers (IP or DNS) are provided
var err error var err error
if ip := net.ParseIP(identifier); ip != nil { switch identifier.Type {
_, err = p.x509Policy.IsIPAllowed(ip) case IP:
} else { _, err = p.x509Policy.IsIPAllowed(net.ParseIP(identifier.Value))
_, err = p.x509Policy.IsDNSAllowed(identifier) case DNS:
_, err = p.x509Policy.IsDNSAllowed(identifier.Value)
} }
return err return err
@ -142,10 +139,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeACME, p.Name, ""), newProvisionerExtensionOption(TypeACME, p.Name, ""),
newForceCNOption(p.ForceCN), newForceCNOption(p.ForceCN),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
} }
@ -166,8 +163,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName())
}
return nil
} }

View file

@ -91,6 +91,7 @@ func TestACME_Init(t *testing.T) {
} }
func TestACME_AuthorizeRenew(t *testing.T) { func TestACME_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *ACME p *ACME
cert *x509.Certificate cert *x509.Certificate
@ -104,21 +105,27 @@ func TestACME_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()), err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateACME() p, err := generateACME()
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
} }
}, },
} }
@ -172,18 +179,18 @@ func TestACME_AuthorizeSign(t *testing.T) {
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeACME)) assert.Equals(t, v.Type, TypeACME)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case *forceCNOption: case *forceCNOption:
assert.Equals(t, v.ForceCN, tc.p.ForceCN) assert.Equals(t, v.ForceCN, tc.p.ForceCN)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
default: default:

View file

@ -17,6 +17,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
@ -264,11 +265,10 @@ type AWS struct {
IIDRoots string `json:"iidRoots,omitempty"` IIDRoots string `json:"iidRoots,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *awsConfig config *awsConfig
audiences Audiences ctl *Controller
x509Policy x509PolicyEngine x509Policy policy.X509Policy
sshHostPolicy *hostPolicyEngine sshHostPolicy policy.HostPolicy
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -402,15 +402,11 @@ func (p *AWS) Init(config Config) (err error) {
case p.InstanceAge.Value() < 0: case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative") return errors.New("provisioner instanceAge cannot be negative")
} }
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Add default config // Add default config
if p.config, err = newAWSConfig(p.IIDRoots); err != nil { if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
return err return err
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
// validate IMDS versions // validate IMDS versions
if len(p.IMDSVersions) == 0 { if len(p.IMDSVersions) == 0 {
@ -428,16 +424,18 @@ func (p *AWS) Init(config Config) (err error) {
} }
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for host certificates // Initialize the SSH allow/deny policy engine for host certificates
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
return nil config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
} }
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
@ -485,11 +483,11 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
commonNameValidator(payload.Claims.Subject), commonNameValidator(payload.Claims.Subject),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
), nil ), nil
} }
@ -499,10 +497,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName())
}
return nil
} }
// assertConfig initializes the config if it has not been initialized // assertConfig initializes the config if it has not been initialized
@ -561,7 +556,7 @@ func (p *AWS) readURL(url string) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return nil, fmt.Errorf("Request for metadata returned non-successful status code %d", return nil, fmt.Errorf("request for metadata returned non-successful status code %d",
resp.StatusCode) resp.StatusCode)
} }
@ -594,7 +589,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode) return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode)
} }
token, err := io.ReadAll(resp.Body) token, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
@ -677,7 +672,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) { if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) {
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)") return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
} }
@ -717,7 +712,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
@ -765,11 +760,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -677,18 +677,18 @@ func TestAWS_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAWS)) assert.Equals(t, v.Type, TypeAWS)
assert.Equals(t, v.Name, tt.aws.GetName()) assert.Equals(t, v.Name, tt.aws.GetName())
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
assert.Len(t, 2, v.KeyValuePairs) assert.Len(t, 2, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), tt.args.cn) assert.Equals(t, string(v), tt.args.cn)
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
case emailAddressesValidator: case emailAddressesValidator:
@ -728,7 +728,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com") t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
@ -749,7 +749,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -826,6 +826,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
} }
func TestAWS_AuthorizeRenew(t *testing.T) { func TestAWS_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateAWS() p1, err := generateAWS()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateAWS() p2, err := generateAWS()
@ -834,7 +835,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -847,8 +848,14 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -13,10 +13,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens. // azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
@ -30,7 +33,7 @@ const azureDefaultAudience = "https://management.azure.com/"
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim. // azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
// Using case insensitive as resourceGroups appears as resourcegroups. // Using case insensitive as resourceGroups appears as resourcegroups.
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`) var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
type azureConfig struct { type azureConfig struct {
oidcDiscoveryURL string oidcDiscoveryURL string
@ -96,12 +99,12 @@ type Azure struct {
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *azureConfig config *azureConfig
oidcConfig openIDConfiguration oidcConfig openIDConfiguration
keyStore *keyStore keyStore *keyStore
x509Policy x509PolicyEngine ctl *Controller
sshHostPolicy *hostPolicyEngine x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -205,37 +208,34 @@ func (p *Azure) Init(config Config) (err error) {
case p.Audience == "": // use default audience case p.Audience == "": // use default audience
p.Audience = azureDefaultAudience p.Audience = azureDefaultAudience
} }
// Initialize config // Initialize config
p.assertConfig() p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
return err return
} }
if err := p.oidcConfig.Validate(); err != nil { if err := p.oidcConfig.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL) return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
} }
// Get JWK key set // Get JWK key set
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil { if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
return err return
} }
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for host certificates // Initialize the SSH allow/deny policy engine for host certificates
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// authorizeToken returns the claims, name, group, subscription, identityObjectID, error. // authorizeToken returns the claims, name, group, subscription, identityObjectID, error.
@ -275,11 +275,14 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str
} }
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 { if len(re) != 5 {
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
} }
var subscription, group, name string
identityObjectID := claims.ObjectID identityObjectID := claims.ObjectID
subscription, group, name := re[1], re[2], re[3] subscription, group, name = re[1], re[2], re[4]
return &claims, name, group, subscription, identityObjectID, nil return &claims, name, group, subscription, identityObjectID, nil
} }
@ -367,10 +370,10 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
), nil ), nil
} }
@ -380,15 +383,12 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
} }
@ -433,11 +433,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -95,7 +95,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
@ -237,7 +237,7 @@ func TestAzure_authorizeToken(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), jwk) time.Now(), jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -252,7 +252,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -267,7 +267,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
"foo", "subscriptionID", "resourceGroup", "virtualMachine", "foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -321,7 +321,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -437,28 +437,28 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience", failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), badKey) time.Now(), badKey)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -506,18 +506,18 @@ func TestAzure_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAzure)) assert.Equals(t, v.Type, TypeAzure)
assert.Equals(t, v.Name, tt.azure.GetName()) assert.Equals(t, v.Name, tt.azure.GetName())
assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Equals(t, v.CredentialID, tt.azure.TenantID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), "virtualMachine") assert.Equals(t, string(v), "virtualMachine")
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case emailAddressesValidator: case emailAddressesValidator:
@ -538,6 +538,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
} }
func TestAzure_AuthorizeRenew(t *testing.T) { func TestAzure_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateAzure() p1, err := generateAzure()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateAzure() p2, err := generateAzure()
@ -546,7 +547,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -559,8 +560,14 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -597,7 +604,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("subject", "caURL") t1, err := p1.GetIdentityToken("subject", "caURL")
@ -618,7 +625,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"virtualMachine"}, CertType: "host", Principals: []string{"virtualMachine"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),

View file

@ -10,10 +10,10 @@ import (
// Claims so that individual provisioners can override global claims. // Claims so that individual provisioners can override global claims.
type Claims struct { type Claims struct {
// TLS CA properties // TLS CA properties
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"`
// SSH CA properties // SSH CA properties
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
@ -22,6 +22,10 @@ type Claims struct {
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
EnableSSHCA *bool `json:"enableSSHCA,omitempty"` EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
// Renewal properties
DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
} }
// Claimer is the type that controls claims. It provides an interface around the // Claimer is the type that controls claims. It provides an interface around the
@ -40,19 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
// Claims returns the merge of the inner and global claims. // Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims { func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal() disableRenewal := c.IsDisableRenewal()
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
enableSSHCA := c.IsSSHCAEnabled() enableSSHCA := c.IsSSHCAEnabled()
return Claims{ return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()}, MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
DisableRenewal: &disableRenewal, MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, EnableSSHCA: &enableSSHCA,
EnableSSHCA: &enableSSHCA, DisableRenewal: &disableRenewal,
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
} }
} }
@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool {
return *c.claims.DisableRenewal return *c.claims.DisableRenewal
} }
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
// certificate is expired. If the property is not set within the provisioner
// then the global value from the authority configuration will be used.
func (c *Claimer) AllowRenewAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
return *c.global.AllowRenewAfterExpiry
}
return *c.claims.AllowRenewAfterExpiry
}
// DefaultSSHCertDuration returns the default SSH certificate duration for the // DefaultSSHCertDuration returns the default SSH certificate duration for the
// given certificate type. // given certificate type.
func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) { func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) {

View file

@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
// proper id to load the provisioner. // proper id to load the provisioner.
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) { func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
for _, e := range cert.Extensions { for _, e := range cert.Extensions {
if e.Id.Equal(stepOIDProvisioner) { if e.Id.Equal(StepOIDProvisioner) {
var provisioner stepProvisionerASN1 var provisioner extensionASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false return nil, false
} }

View file

@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) {
} }
func TestCollection_LoadByCertificate(t *testing.T) { func TestCollection_LoadByCertificate(t *testing.T) {
mustExtension := func(typ Type, name, credentialID string) pkix.Extension {
e := Extension{
Type: typ, Name: name, CredentialID: credentialID,
}
ext, err := e.ToExtension()
if err != nil {
t.Fatal(err)
}
return ext
}
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateOIDC() p2, err := generateOIDC()
@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) {
byName.Store(p2.GetName(), p2) byName.Store(p2.GetName(), p2)
byName.Store(p3.GetName(), p3) byName.Store(p3.GetName(), p3)
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
assert.FatalError(t, err)
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
assert.FatalError(t, err)
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
assert.FatalError(t, err)
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
assert.FatalError(t, err)
ok1Cert := &x509.Certificate{ ok1Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok1Ext}, Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)},
} }
ok2Cert := &x509.Certificate{ ok2Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok2Ext}, Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)},
} }
ok3Cert := &x509.Certificate{ ok3Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok3Ext}, Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")},
} }
notFoundCert := &x509.Certificate{ notFoundCert := &x509.Certificate{
Extensions: []pkix.Extension{notFoundExt}, Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")},
} }
badCert := &x509.Certificate{ badCert := &x509.Certificate{
Extensions: []pkix.Extension{ Extensions: []pkix.Extension{
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")}, {Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")},
}, },
} }

View file

@ -0,0 +1,194 @@
package provisioner
import (
"context"
"crypto/x509"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
// Controller wraps a provisioner with other attributes useful in callback
// functions.
type Controller struct {
Interface
Audiences *Audiences
Claimer *Claimer
IdentityFunc GetIdentityFunc
AuthorizeRenewFunc AuthorizeRenewFunc
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
// NewController initializes a new provisioner controller.
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
claimer, err := NewClaimer(claims, config.Claims)
if err != nil {
return nil, err
}
return &Controller{
Interface: p,
Audiences: &config.Audiences,
Claimer: claimer,
IdentityFunc: config.GetIdentityFunc,
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
}, nil
}
// GetIdentity returns the identity for a given email.
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
if c.IdentityFunc != nil {
return c.IdentityFunc(ctx, c.Interface, email)
}
return DefaultIdentityFunc(ctx, c.Interface, email)
}
// AuthorizeRenew returns nil if the given cert can be renewed, returns an error
// otherwise.
func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if c.AuthorizeRenewFunc != nil {
return c.AuthorizeRenewFunc(ctx, c, cert)
}
return DefaultAuthorizeRenew(ctx, c, cert)
}
// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an
// error otherwise.
func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error {
if c.AuthorizeSSHRenewFunc != nil {
return c.AuthorizeSSHRenewFunc(ctx, c, cert)
}
return DefaultAuthorizeSSHRenew(ctx, c, cert)
}
// Identity is the type representing an externally supplied identity that is used
// by provisioners to populate certificate fields.
type Identity struct {
Usernames []string `json:"usernames"`
Permissions `json:"permissions"`
}
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// AuthorizeRenewFunc is a function that returns nil if the renewal of a
// certificate is enabled.
type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error
// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the
// given SSH certificate is enabled.
type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error
// DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might
// contain empty strings.
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
// OIDC principals would be:
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
// 2. Sanitized local.
// 3. Raw local (if different).
// 4. Email address.
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
usernames := []string{name}
if i := strings.LastIndex(email, "@"); i >= 0 {
usernames = append(usernames, email[:i])
}
usernames = append(usernames, email)
return &Identity{
Usernames: SanitizeStringSlices(usernames),
}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It
// will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled.
func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error {
if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
}
now := time.Now().Truncate(time.Second)
if now.Before(cert.NotBefore) {
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
}
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
return nil
}
// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It
// will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled.
func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
}
unixNow := time.Now().Unix()
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return errs.Unauthorized("certificate is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
return nil
}
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeStringSlices removes duplicated an empty strings.
func SanitizeStringSlices(original []string) []string {
output := []string{}
seen := make(map[string]struct{})
for _, entry := range original {
if entry == "" {
continue
}
if _, value := seen[entry]; !value {
seen[entry] = struct{}{}
output = append(output, entry)
}
}
return output
}
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}

View file

@ -0,0 +1,391 @@
package provisioner
import (
"context"
"crypto/x509"
"fmt"
"reflect"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
var trueValue = true
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
t.Helper()
c, err := NewClaimer(claims, global)
if err != nil {
t.Fatal(err)
}
return c
}
func mustDuration(t *testing.T, s string) *Duration {
t.Helper()
d, err := NewDuration(s)
if err != nil {
t.Fatal(err)
}
return d
}
func TestNewController(t *testing.T) {
type args struct {
p Interface
claims *Claims
config Config
}
tests := []struct {
name string
args args
want *Controller
wantErr bool
}{
{"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, false},
{"ok with claims", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
}, false},
{"fail claimer", args{&JWK{}, &Claims{
MinTLSDur: mustDuration(t, "24h"),
MaxTLSDur: mustDuration(t, "2h"),
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
if (err != nil) != tt.wantErr {
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewController() = %v, want %v", got, tt.want)
}
})
}
}
func TestController_GetIdentity(t *testing.T) {
ctx := context.Background()
type fields struct {
Interface Interface
IdentityFunc GetIdentityFunc
}
type args struct {
ctx context.Context
email string
}
tests := []struct {
name string
fields fields
args args
want *Identity
wantErr bool
}{
{"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{
Usernames: []string{"jane", "jane@doe.org"},
}, false},
{"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"jane"}}, nil
}}, args{ctx, "jane@doe.org"}, &Identity{
Usernames: []string{"jane"},
}, false},
{"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true},
{"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
return nil, fmt.Errorf("an error")
}}, args{ctx, "jane@doe.org"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
IdentityFunc: tt.fields.IdentityFunc,
}
got, err := c.GetIdentity(tt.args.ctx, tt.args.email)
if (err != nil) != tt.wantErr {
t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want)
}
})
}
}
func TestController_AuthorizeRenew(t *testing.T) {
ctx := context.Background()
now := time.Now().Truncate(time.Second)
type fields struct {
Interface Interface
Claimer *Claimer
AuthorizeRenewFunc AuthorizeRenewFunc
}
type args struct {
ctx context.Context
cert *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(time.Hour),
NotAfter: now.Add(2 * time.Hour),
}}, true},
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, true},
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return fmt.Errorf("an error")
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
Claimer: tt.fields.Claimer,
AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
}
if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestController_AuthorizeSSHRenew(t *testing.T) {
ctx := context.Background()
now := time.Now()
type fields struct {
Interface Interface
Claimer *Claimer
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
type args struct {
ctx context.Context
cert *ssh.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(time.Hour).Unix()),
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
}}, true},
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, true},
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return fmt.Errorf("an error")
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
Claimer: tt.fields.Claimer,
AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
}
if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDefaultAuthorizeRenew(t *testing.T) {
ctx := context.Background()
now := time.Now().Truncate(time.Second)
type args struct {
ctx context.Context
p *Controller
cert *x509.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
{"fail disabled", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
{"fail not yet valid", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(time.Hour),
NotAfter: now.Add(2 * time.Hour),
}}, true},
{"fail expired", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
ctx := context.Background()
now := time.Now()
type args struct {
ctx context.Context
p *Controller
cert *ssh.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
{"fail disabled", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
{"fail not yet valid", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(time.Hour).Unix()),
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
}}, true},
{"fail expired", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -0,0 +1,73 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
)
var (
// StepOIDRoot is the root OID for smallstep.
StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
// StepOIDProvisioner is the OID for the provisioner extension.
StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...)
)
// Extension is the Go representation of the provisioner extension.
type Extension struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
type extensionASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
// Marshal marshals the extension using encoding/asn1.
func (e *Extension) Marshal() ([]byte, error) {
return asn1.Marshal(extensionASN1{
Type: int(e.Type),
Name: []byte(e.Name),
CredentialID: []byte(e.CredentialID),
KeyValuePairs: e.KeyValuePairs,
})
}
// ToExtension returns the pkix.Extension representation of the provisioner
// extension.
func (e *Extension) ToExtension() (pkix.Extension, error) {
b, err := e.Marshal()
if err != nil {
return pkix.Extension{}, err
}
return pkix.Extension{
Id: StepOIDProvisioner,
Value: b,
}, nil
}
// GetProvisionerExtension goes through all the certificate extensions and
// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1).
func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) {
for _, e := range cert.Extensions {
if e.Id.Equal(StepOIDProvisioner) {
var provisioner extensionASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false
}
return &Extension{
Type: Type(provisioner.Type),
Name: string(provisioner.Name),
CredentialID: string(provisioner.CredentialID),
KeyValuePairs: provisioner.KeyValuePairs,
}, true
}
}
return nil, false
}

View file

@ -0,0 +1,158 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"reflect"
"testing"
"go.step.sm/crypto/pemutil"
)
func TestExtension_Marshal(t *testing.T) {
type fields struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
tests := []struct {
name string
fields fields
want []byte
wantErr bool
}{
{"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
}, false},
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
0x13, 0x03, 0x62, 0x61, 0x72,
}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Extension{
Type: tt.fields.Type,
Name: tt.fields.Name,
CredentialID: tt.fields.CredentialID,
KeyValuePairs: tt.fields.KeyValuePairs,
}
got, err := e.Marshal()
if (err != nil) != tt.wantErr {
t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want)
}
})
}
}
func TestExtension_ToExtension(t *testing.T) {
type fields struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
tests := []struct {
name string
fields fields
want pkix.Extension
wantErr bool
}{
{"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
},
}, false},
{"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
},
}, false},
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
0x13, 0x03, 0x62, 0x61, 0x72,
},
}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Extension{
Type: tt.fields.Type,
Name: tt.fields.Name,
CredentialID: tt.fields.CredentialID,
KeyValuePairs: tt.fields.KeyValuePairs,
}
got, err := e.ToExtension()
if (err != nil) != tt.wantErr {
t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetProvisionerExtension(t *testing.T) {
mustCertificate := func(fn string) *x509.Certificate {
cert, err := pemutil.ReadCertificate(fn)
if err != nil {
t.Fatal(err)
}
return cert
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
args args
want *Extension
want1 bool
}{
{"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{
Type: TypeJWK,
Name: "mariano@smallstep.com",
CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ",
}, true},
{"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false},
{"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := GetProvisionerExtension(tt.args.cert)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1)
}
})
}
}

View file

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
@ -88,12 +89,11 @@ type GCP struct {
InstanceAge Duration `json:"instanceAge,omitempty"` InstanceAge Duration `json:"instanceAge,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *gcpConfig config *gcpConfig
keyStore *keyStore keyStore *keyStore
audiences Audiences ctl *Controller
x509Policy x509PolicyEngine x509Policy policy.X509Policy
sshHostPolicy *hostPolicyEngine sshHostPolicy policy.HostPolicy
} }
// GetID returns the provisioner unique identifier. The name should uniquely // GetID returns the provisioner unique identifier. The name should uniquely
@ -196,8 +196,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
} }
// Init validates and initializes the GCP provisioner. // Init validates and initializes the GCP provisioner.
func (p *GCP) Init(config Config) error { func (p *GCP) Init(config Config) (err error) {
var err error
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -206,30 +205,28 @@ func (p *GCP) Init(config Config) error {
case p.InstanceAge.Value() < 0: case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative") return errors.New("provisioner instanceAge cannot be negative")
} }
// Initialize config // Initialize config
p.assertConfig() p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize key store // Initialize key store
p.keyStore, err = newKeyStore(p.config.CertsURL) if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
if err != nil { return
return err
} }
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for host certificates // Initialize the SSH allow/deny policy engine for host certificates
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
@ -281,20 +278,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
), nil ), nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName())
}
return nil
} }
// assertConfig initializes the config if it has not been initialized. // assertConfig initializes the config if it has not been initialized.
@ -341,7 +335,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) { if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) {
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
} }
@ -396,7 +390,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
@ -444,11 +438,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -549,18 +549,18 @@ func TestGCP_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeGCP)) assert.Equals(t, v.Type, TypeGCP)
assert.Equals(t, v.Name, tt.gcp.GetName()) assert.Equals(t, v.Name, tt.gcp.GetName())
assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0])
assert.Len(t, 4, v.KeyValuePairs) assert.Len(t, 4, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration())
case commonNameSliceValidator: case commonNameSliceValidator:
assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case emailAddressesValidator: case emailAddressesValidator:
@ -597,7 +597,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0], t1, err := generateGCPToken(p1.ServiceAccounts[0],
@ -624,7 +624,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -700,6 +700,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
} }
func TestGCP_AuthorizeRenew(t *testing.T) { func TestGCP_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateGCP() p1, err := generateGCP()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateGCP() p2, err := generateGCP()
@ -708,7 +709,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -721,8 +722,14 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renewal-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -7,10 +7,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// jwtPayload extends jwt.Claims with step attributes. // jwtPayload extends jwt.Claims with step attributes.
@ -35,11 +38,10 @@ type JWK struct {
EncryptedKey string `json:"encryptedKey,omitempty"` EncryptedKey string `json:"encryptedKey,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences x509Policy policy.X509Policy
x509Policy x509PolicyEngine sshHostPolicy policy.HostPolicy
sshHostPolicy *hostPolicyEngine sshUserPolicy policy.UserPolicy
sshUserPolicy *userPolicyEngine
} }
// GetID returns the provisioner unique identifier. The name and credential id // GetID returns the provisioner unique identifier. The name and credential id
@ -101,28 +103,23 @@ func (p *JWK) Init(config Config) (err error) {
return errors.New("provisioner key cannot be empty") return errors.New("provisioner key cannot be empty")
} }
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for user certificates // Initialize the SSH allow/deny policy engine for user certificates
if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for host certificates // Initialize the SSH allow/deny policy engine for host certificates
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
p.audiences = config.Audiences p.ctl, err = NewController(p, p.Claims, config)
return err return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -164,14 +161,14 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke) // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
} }
@ -198,12 +195,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs), defaultSANsValidator(claims.SANs),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
}, nil }, nil
} }
@ -213,12 +210,13 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { // if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName()) // return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName())
} // }
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew) // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew)
//return p.authorizeRenew(cert) //return p.authorizeRenew(cert)
return nil // return nil
return p.ctl.AuthorizeRenew(ctx, cert)
} }
// func (p *JWK) authorizeRenew(cert *x509.Certificate) error { // func (p *JWK) authorizeRenew(cert *x509.Certificate) error {
@ -231,10 +229,10 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
} }
@ -291,11 +289,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
return append(signOptions, return append(signOptions,
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed
@ -305,7 +303,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.SSHRevoke) _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
// TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke) // TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
} }

View file

@ -76,13 +76,13 @@ func TestJWK_Init(t *testing.T) {
}, },
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
err: errors.New("claims: MinTLSCertDuration must be greater than 0"), err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
} }
}, },
"ok": func(t *testing.T) ProvisionerValidateTest { "ok": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
} }
}, },
} }
@ -300,18 +300,18 @@ func TestJWK_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeJWK)) assert.Equals(t, v.Type, TypeJWK)
assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), "subject") assert.Equals(t, string(v), "subject")
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case defaultSANsValidator: case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans) assert.Equals(t, []string(v), tt.sans)
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
@ -327,6 +327,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
} }
func TestJWK_AuthorizeRenew(t *testing.T) { func TestJWK_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateJWK() p2, err := generateJWK()
@ -335,7 +336,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -348,8 +349,14 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -375,7 +382,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p2.Claims = &Claims{EnableSSHCA: &disable} p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p1.EncryptedKey) jwk, err := decryptJSONWebKey(p1.EncryptedKey)
@ -404,8 +411,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name"}, CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
@ -487,8 +494,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
signer, err := generateJSONWebKey() signer, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name"}, CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),

View file

@ -10,11 +10,14 @@ import (
"net/http" "net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// NOTE: There can be at most one kubernetes service account provisioner configured // NOTE: There can be at most one kubernetes service account provisioner configured
@ -42,19 +45,18 @@ type k8sSAPayload struct {
// entity trusted to make signature requests. // entity trusted to make signature requests.
type K8sSA struct { type K8sSA struct {
*base *base
ID string `json:"-"` ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
PubKeys []byte `json:"publicKeys,omitempty"` PubKeys []byte `json:"publicKeys,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
audiences Audiences
//kauthn kauthn.AuthenticationV1Interface //kauthn kauthn.AuthenticationV1Interface
pubKeys []interface{} pubKeys []interface{}
x509Policy x509PolicyEngine ctl *Controller
sshHostPolicy *hostPolicyEngine x509Policy policy.X509Policy
sshUserPolicy *userPolicyEngine sshHostPolicy policy.HostPolicy
sshUserPolicy policy.UserPolicy
} }
// GetID returns the provisioner unique identifier. The name and credential id // GetID returns the provisioner unique identifier. The name and credential id
@ -142,28 +144,23 @@ func (p *K8sSA) Init(config Config) (err error) {
p.kauthn = k8s.AuthenticationV1() p.kauthn = k8s.AuthenticationV1()
*/ */
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for user certificates // Initialize the SSH allow/deny policy engine for user certificates
if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for host certificates // Initialize the SSH allow/deny policy engine for host certificates
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
p.audiences = config.Audiences p.ctl, err = NewController(p, p.Claims, config)
return err return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -230,13 +227,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
} }
@ -259,28 +256,25 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
}, nil }, nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign validates an request for an SSH certificate. // AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
} }
@ -302,11 +296,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Require type, key-id and principals in the SignSSHOptions. // Require type, key-id and principals in the SignSSHOptions.
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -179,6 +179,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
} }
func TestK8sSA_AuthorizeRenew(t *testing.T) { func TestK8sSA_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *K8sSA p *K8sSA
cert *x509.Certificate cert *x509.Certificate
@ -192,21 +193,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()), err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateK8sSA(nil) p, err := generateK8sSA(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
} }
}, },
} }
@ -275,16 +282,16 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeK8sSA)) assert.Equals(t, v.Type, TypeK8sSA)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
default: default:
@ -313,7 +320,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p.Claims = &Claims{EnableSSHCA: &disable} p.Claims = &Claims{EnableSSHCA: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
@ -365,11 +372,11 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case *sshCertOptionsRequireValidator: case *sshCertOptionsRequireValidator:
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator: case *sshCertDefaultValidator:
case *sshDefaultDuration: case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshNamePolicyValidator: case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine) assert.Equals(t, nil, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine) assert.Equals(t, nil, v.hostPolicyEngine)

View file

@ -10,6 +10,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
nebula "github.com/slackhq/nebula/cert" nebula "github.com/slackhq/nebula/cert"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
@ -40,16 +41,15 @@ type Nebula struct {
Roots []byte `json:"roots"` Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
caPool *nebula.NebulaCAPool caPool *nebula.NebulaCAPool
audiences Audiences ctl *Controller
x509Policy x509PolicyEngine x509Policy policy.X509Policy
sshHostPolicy *hostPolicyEngine sshHostPolicy policy.HostPolicy
sshUserPolicy *userPolicyEngine sshUserPolicy policy.UserPolicy
} }
// Init verifies and initializes the Nebula provisioner. // Init verifies and initializes the Nebula provisioner.
func (p *Nebula) Init(config Config) error { func (p *Nebula) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -59,34 +59,29 @@ func (p *Nebula) Init(config Config) error {
return errors.New("provisioner root(s) cannot be empty") return errors.New("provisioner root(s) cannot be empty")
} }
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots) p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots)
if err != nil { if err != nil {
return errs.InternalServer("failed to create ca pool: %v", err) return errs.InternalServer("failed to create ca pool: %v", err)
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for user certificates // Initialize the SSH allow/deny policy engine for user certificates
if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for host certificates // Initialize the SSH allow/deny policy engine for host certificates
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
return nil config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
} }
// GetID returns the provisioner id. // GetID returns the provisioner id.
@ -138,7 +133,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) {
// AuthorizeSign returns the list of SignOption for a Sign request. // AuthorizeSign returns the list of SignOption for a Sign request.
func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
crt, claims, err := p.authorizeToken(token, p.audiences.Sign) crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -172,7 +167,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeNebula, p.Name, ""), newProvisionerExtensionOption(TypeNebula, p.Name, ""),
profileLimitDuration{ profileLimitDuration{
def: p.claimer.DefaultTLSCertDuration(), def: p.ctl.Claimer.DefaultTLSCertDuration(),
notBefore: crt.Details.NotBefore, notBefore: crt.Details.NotBefore,
notAfter: crt.Details.NotAfter, notAfter: crt.Details.NotAfter,
}, },
@ -183,7 +178,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
IPs: crt.Details.Ips, IPs: crt.Details.Ips,
}, },
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
}, nil }, nil
} }
@ -191,11 +186,11 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
// Currently the Nebula provisioner only grants host SSH certificates. // Currently the Nebula provisioner only grants host SSH certificates.
func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
} }
crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign) crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -273,11 +268,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
return append(signOptions, return append(signOptions,
templateOptions, templateOptions,
// Checks the validity bounds, and set the validity if has not been set. // Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.claimer, crt.Details.NotAfter}, &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed
@ -287,23 +282,20 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error { func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, crt)
return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeRevoke returns an error if the token is not valid. // AuthorizeRevoke returns an error if the token is not valid.
func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error { func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error {
return p.validateToken(token, p.audiences.Revoke) return p.validateToken(token, p.ctl.Audiences.Revoke)
} }
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. // AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
} }
if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil { if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil {
return err return err
} }
return nil return nil

View file

@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) {
func TestNebula_GetTokenID(t *testing.T) { func TestNebula_GetTokenID(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer) c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
_, claims, err := parseToken(t1) _, claims, err := parseToken(t1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv)
pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions, _, _ := mustNebulaProvisioner(t)
pBadOptions.caPool = p.caPool pBadOptions.caPool = p.caPool
@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan", "10.1.0.1"}, Principals: []string{"test.lan", "10.1.0.1"},
}, crt, priv) }, crt, priv)
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv) okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv)
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)),
ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)),
}, crt, priv) }, crt, priv)
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "user", CertType: "user",
}, crt, priv) }, crt, priv)
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, Principals: []string{"test.lan", "10.1.0.1", "foo.bar"},
@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
func TestNebula_AuthorizeRenew(t *testing.T) { func TestNebula_AuthorizeRenew(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
now := time.Now().Truncate(time.Second)
// Ok provisioner // Ok provisioner
p, _, _ := mustNebulaProvisioner(t) p, _, _ := mustNebulaProvisioner(t)
@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) {
args args args args
wantErr bool wantErr bool
}{ }{
{"ok", p, args{ctx, &x509.Certificate{}}, false}, {"ok", p, args{ctx, &x509.Certificate{
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -584,12 +592,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
// Fail different CA // Fail different CA
nc, signer := mustNebulaCA(t) nc, signer := mustNebulaCA(t)
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -618,12 +626,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
// Fail different CA // Fail different CA
nc, signer := mustNebulaCA(t) nc, signer := mustNebulaCA(t)
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
// Provisioner with SSH disabled // Provisioner with SSH disabled
var bFalse bool var bFalse bool
@ -657,7 +665,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
func TestNebula_AuthorizeSSHRenew(t *testing.T) { func TestNebula_AuthorizeSSHRenew(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv) t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -689,7 +697,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) {
func TestNebula_AuthorizeSSHRekey(t *testing.T) { func TestNebula_AuthorizeSSHRekey(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv) t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -726,20 +734,20 @@ func TestNebula_authorizeToken(t *testing.T) {
t1 := now() t1 := now()
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv)
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{ okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan"}, Principals: []string{"test.lan"},
}, crt, priv) }, crt, priv)
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv) okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv)
// Token with errors // Token with errors
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv)
failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
// Not a nebula token // Not a nebula token
jwk, err := generateJSONWebKey() jwk, err := generateJSONWebKey()
@ -761,7 +769,7 @@ func TestNebula_authorizeToken(t *testing.T) {
IssuedAt: jose.NewNumericDate(t1), IssuedAt: jose.NewNumericDate(t1),
NotBefore: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1),
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
Audience: []string{p.audiences.Sign[0]}, Audience: []string{p.ctl.Audiences.Sign[0]},
} }
sshClaims := jose.Claims{ sshClaims := jose.Claims{
ID: "[REPLACEME]", ID: "[REPLACEME]",
@ -770,7 +778,7 @@ func TestNebula_authorizeToken(t *testing.T) {
IssuedAt: jose.NewNumericDate(t1), IssuedAt: jose.NewNumericDate(t1),
NotBefore: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1),
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
Audience: []string{p.audiences.SSHSign[0]}, Audience: []string{p.ctl.Audiences.SSHSign[0]},
} }
type args struct { type args struct {
@ -785,14 +793,14 @@ func TestNebula_authorizeToken(t *testing.T) {
want1 *jwtPayload want1 *jwtPayload
wantErr bool wantErr bool
}{ }{
{"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{ {"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{
Claims: x509Claims, Claims: x509Claims,
SANs: []string{"10.1.0.1"}, SANs: []string{"10.1.0.1"},
}, false}, }, false},
{"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{ {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{
Claims: x509Claims, Claims: x509Claims,
}, false}, }, false},
{"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{ {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
Claims: sshClaims, Claims: sshClaims,
Step: &stepPayload{ Step: &stepPayload{
SSH: &SignSSHOptions{ SSH: &SignSSHOptions{
@ -802,16 +810,16 @@ func TestNebula_authorizeToken(t *testing.T) {
}, },
}, },
}, false}, }, false},
{"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{ {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
Claims: sshClaims, Claims: sshClaims,
}, false}, }, false},
{"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true}, {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true},
{"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true}, {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true}, {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true}, {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true}, {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true}, {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true}, {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -12,10 +12,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// openIDConfiguration contains the necessary properties in the // openIDConfiguration contains the necessary properties in the
@ -92,11 +95,11 @@ type OIDC struct {
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
configuration openIDConfiguration configuration openIDConfiguration
keyStore *keyStore keyStore *keyStore
claimer *Claimer
getIdentityFunc GetIdentityFunc getIdentityFunc GetIdentityFunc
x509Policy x509PolicyEngine ctl *Controller
sshHostPolicy *hostPolicyEngine x509Policy policy.X509Policy
sshUserPolicy *userPolicyEngine sshHostPolicy policy.HostPolicy
sshUserPolicy policy.UserPolicy
} }
func sanitizeEmail(email string) string { func sanitizeEmail(email string) string {
@ -175,11 +178,6 @@ func (o *OIDC) Init(config Config) (err error) {
} }
} }
// Update claims with global ones
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
u, err := url.Parse(o.ConfigurationEndpoint) u, err := url.Parse(o.ConfigurationEndpoint)
if err != nil { if err != nil {
@ -212,21 +210,22 @@ func (o *OIDC) Init(config Config) (err error) {
} }
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if o.x509Policy, err = newX509PolicyEngine(o.Options.GetX509Options()); err != nil { if o.x509Policy, err = policy.NewX509PolicyEngine(o.Options.GetX509Options()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for user certificates // Initialize the SSH allow/deny policy engine for user certificates
if o.sshUserPolicy, err = newSSHUserPolicyEngine(o.Options.GetSSHOptions()); err != nil { if o.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(o.Options.GetSSHOptions()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for host certificates // Initialize the SSH allow/deny policy engine for host certificates
if o.sshHostPolicy, err = newSSHHostPolicyEngine(o.Options.GetSSHOptions()); err != nil { if o.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(o.Options.GetSSHOptions()); err != nil {
return err return err
} }
return nil o.ctl, err = NewController(o, o.Claims, config)
return
} }
// ValidatePayload validates the given token payload. // ValidatePayload validates the given token payload.
@ -378,10 +377,10 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(o.x509Policy), newX509NamePolicyValidator(o.x509Policy),
}, nil }, nil
} }
@ -391,15 +390,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if o.claimer.IsDisableRenewal() { return o.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !o.claimer.IsSSHCAEnabled() { if !o.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName())
} }
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
@ -414,7 +410,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Get the identity using either the default identityFunc or one injected // Get the identity using either the default identityFunc or one injected
// externally. Note that the PreferredUsername might be empty. // externally. Note that the PreferredUsername might be empty.
// TBD: Would preferred_username present a safety issue here? // TBD: Would preferred_username present a safety issue here?
iden, err := o.getIdentityFunc(ctx, o, claims.Email) iden, err := o.ctl.GetIdentity(ctx, claims.Email)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
} }
@ -465,11 +461,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
return append(signOptions, return append(signOptions,
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{o.claimer}, &sshDefaultDuration{o.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{o.claimer}, &sshCertValidityValidator{o.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -327,16 +327,16 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeOIDC)) assert.Equals(t, v.Type, TypeOIDC)
assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Equals(t, v.CredentialID, tt.prov.ClientID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case emailOnlyIdentity: case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com") assert.Equals(t, string(v), "name@smallstep.com")
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
@ -413,6 +413,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
} }
func TestOIDC_AuthorizeRenew(t *testing.T) { func TestOIDC_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateOIDC() p1, err := generateOIDC()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateOIDC() p2, err := generateOIDC()
@ -421,7 +422,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -434,8 +435,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -480,7 +487,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p6.Claims = &Claims{EnableSSHCA: &disable} p6.Claims = &Claims{EnableSSHCA: &disable}
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
// Update configuration endpoints and initialize // Update configuration endpoints and initialize
@ -496,10 +503,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p4.Init(config))
assert.FatalError(t, p5.Init(config)) assert.FatalError(t, p5.Init(config))
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"max", "mariano"}}, nil return &Identity{Usernames: []string{"max", "mariano"}}, nil
} }
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
return nil, errors.New("force") return nil, errors.New("force")
} }
// Additional test needed for empty usernames and duplicate email and usernames // Additional test needed for empty usernames and duplicate email and usernames
@ -529,8 +536,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name", "name@smallstep.com"}, CertType: "user", Principals: []string{"name", "name@smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),

View file

@ -5,8 +5,11 @@ import (
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
) )
// CertificateOptions is an interface that returns a list of options passed when // CertificateOptions is an interface that returns a list of options passed when
@ -58,10 +61,10 @@ type X509Options struct {
TemplateData json.RawMessage `json:"templateData,omitempty"` TemplateData json.RawMessage `json:"templateData,omitempty"`
// AllowedNames contains the SANs the provisioner is authorized to sign // AllowedNames contains the SANs the provisioner is authorized to sign
AllowedNames *X509NameOptions `json:"allow,omitempty"` AllowedNames *policy.X509NameOptions
// DeniedNames contains the SANs the provisioner is not authorized to sign // DeniedNames contains the SANs the provisioner is not authorized to sign
DeniedNames *X509NameOptions `json:"deny,omitempty"` DeniedNames *policy.X509NameOptions
} }
// HasTemplate returns true if a template is defined in the provisioner options. // HasTemplate returns true if a template is defined in the provisioner options.
@ -69,41 +72,24 @@ func (o *X509Options) HasTemplate() bool {
return o != nil && (o.Template != "" || o.TemplateFile != "") return o != nil && (o.Template != "" || o.TemplateFile != "")
} }
// GetAllowedNameOptions returns the AllowedNameOptions, which models the // GetAllowedNameOptions returns the AllowedNames, which models the
// SANs that a provisioner is authorized to sign x509 certificates for. // SANs that a provisioner is authorized to sign x509 certificates for.
func (o *X509Options) GetAllowedNameOptions() *X509NameOptions { func (o *X509Options) GetAllowedNameOptions() *policy.X509NameOptions {
if o == nil { if o == nil {
return nil return nil
} }
return o.AllowedNames return o.AllowedNames
} }
// GetDeniedNameOptions returns the DeniedNameOptions, which models the // GetDeniedNameOptions returns the DeniedNames, which models the
// SANs that a provisioner is NOT authorized to sign x509 certificates for. // SANs that a provisioner is NOT authorized to sign x509 certificates for.
func (o *X509Options) GetDeniedNameOptions() *X509NameOptions { func (o *X509Options) GetDeniedNameOptions() *policy.X509NameOptions {
if o == nil { if o == nil {
return nil return nil
} }
return o.DeniedNames return o.DeniedNames
} }
// X509NameOptions models the X509 name policy configuration.
type X509NameOptions struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
URIDomains []string `json:"uri,omitempty"`
}
// HasNames checks if the AllowedNameOptions has one or more
// names configured.
func (o *X509NameOptions) HasNames() bool {
return len(o.DNSDomains) > 0 ||
len(o.IPRanges) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.URIDomains) > 0
}
// TemplateOptions generates a CertificateOptions with the template and data // TemplateOptions generates a CertificateOptions with the template and data
// defined in the ProvisionerOptions, the provisioner generated data, and the // defined in the ProvisionerOptions, the provisioner generated data, and the
// user data provided in the request. If no template has been provided, // user data provided in the request. If no template has been provided,

View file

@ -1,156 +0,0 @@
package provisioner
import (
"fmt"
"github.com/smallstep/certificates/policy"
"golang.org/x/crypto/ssh"
)
type sshPolicyEngineType string
const (
userPolicyEngineType sshPolicyEngineType = "user"
hostPolicyEngineType sshPolicyEngineType = "host"
)
var certTypeToPolicyEngineType = map[uint32]sshPolicyEngineType{
uint32(ssh.UserCert): userPolicyEngineType,
uint32(ssh.HostCert): hostPolicyEngineType,
}
type x509PolicyEngine interface {
policy.X509NamePolicyEngine
}
type userPolicyEngine struct {
policy.SSHNamePolicyEngine
}
type hostPolicyEngine struct {
policy.SSHNamePolicyEngine
}
// newX509PolicyEngine creates a new x509 name policy engine
func newX509PolicyEngine(x509Opts *X509Options) (x509PolicyEngine, error) {
if x509Opts == nil {
return nil, nil
}
options := []policy.NamePolicyOption{
policy.WithSubjectCommonNameVerification(), // enable x509 Subject Common Name validation by default
}
allowed := x509Opts.GetAllowedNameOptions()
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedDNSDomains(allowed.DNSDomains),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses),
policy.WithPermittedURIDomains(allowed.URIDomains),
)
}
denied := x509Opts.GetDeniedNameOptions()
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedDNSDomains(denied.DNSDomains),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges),
policy.WithExcludedEmailAddresses(denied.EmailAddresses),
policy.WithExcludedURIDomains(denied.URIDomains),
)
}
return policy.New(options...)
}
// newSSHUserPolicyEngine creates a new SSH user certificate policy engine
func newSSHUserPolicyEngine(sshOpts *SSHOptions) (*userPolicyEngine, error) {
policyEngine, err := newSSHPolicyEngine(sshOpts, userPolicyEngineType)
if err != nil {
return nil, err
}
// ensure we're not wrapping a nil engine
if policyEngine == nil {
return nil, nil
}
return &userPolicyEngine{
SSHNamePolicyEngine: policyEngine,
}, nil
}
// newSSHHostPolicyEngine create a new SSH host certificate policy engine
func newSSHHostPolicyEngine(sshOpts *SSHOptions) (*hostPolicyEngine, error) {
policyEngine, err := newSSHPolicyEngine(sshOpts, hostPolicyEngineType)
if err != nil {
return nil, err
}
// ensure we're not wrapping a nil engine
if policyEngine == nil {
return nil, nil
}
return &hostPolicyEngine{
SSHNamePolicyEngine: policyEngine,
}, nil
}
// newSSHPolicyEngine creates a new SSH name policy engine
func newSSHPolicyEngine(sshOpts *SSHOptions, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
if sshOpts == nil {
return nil, nil
}
var (
allowed *SSHNameOptions
denied *SSHNameOptions
)
// TODO: embed the type in the policy engine itself for reference?
switch typ {
case userPolicyEngineType:
if sshOpts.User != nil {
allowed = sshOpts.User.GetAllowedNameOptions()
denied = sshOpts.User.GetDeniedNameOptions()
}
case hostPolicyEngineType:
if sshOpts.Host != nil {
allowed = sshOpts.Host.AllowedNames
denied = sshOpts.Host.DeniedNames
}
default:
return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
}
options := []policy.NamePolicyOption{}
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedDNSDomains(allowed.DNSDomains),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses),
policy.WithPermittedPrincipals(allowed.Principals),
)
}
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedDNSDomains(denied.DNSDomains),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges),
policy.WithExcludedEmailAddresses(denied.EmailAddresses),
policy.WithExcludedPrincipals(denied.Principals),
)
}
// Return nil, because there's no policy to execute. This is
// important, because the logic that determines user vs. host certs
// are allowed depends on this fact. The two policy engines are
// not aware of eachother, so this check is performed in the
// SSH name validator, instead.
if len(options) == 0 {
return nil, nil
}
return policy.New(options...)
}

View file

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
stderrors "errors" stderrors "errors"
"net/url" "net/url"
"regexp"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -47,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse")
// Audiences stores all supported audiences by request type. // Audiences stores all supported audiences by request type.
type Audiences struct { type Audiences struct {
Sign []string Sign []string
Renew []string
Revoke []string Revoke []string
SSHSign []string SSHSign []string
SSHRevoke []string SSHRevoke []string
@ -57,6 +57,7 @@ type Audiences struct {
// All returns all supported audiences across all request types in one list. // All returns all supported audiences across all request types in one list.
func (a Audiences) All() (auds []string) { func (a Audiences) All() (auds []string) {
auds = a.Sign auds = a.Sign
auds = append(auds, a.Renew...)
auds = append(auds, a.Revoke...) auds = append(auds, a.Revoke...)
auds = append(auds, a.SSHSign...) auds = append(auds, a.SSHSign...)
auds = append(auds, a.SSHRevoke...) auds = append(auds, a.SSHRevoke...)
@ -70,6 +71,7 @@ func (a Audiences) All() (auds []string) {
func (a Audiences) WithFragment(fragment string) Audiences { func (a Audiences) WithFragment(fragment string) Audiences {
ret := Audiences{ ret := Audiences{
Sign: make([]string, len(a.Sign)), Sign: make([]string, len(a.Sign)),
Renew: make([]string, len(a.Renew)),
Revoke: make([]string, len(a.Revoke)), Revoke: make([]string, len(a.Revoke)),
SSHSign: make([]string, len(a.SSHSign)), SSHSign: make([]string, len(a.SSHSign)),
SSHRevoke: make([]string, len(a.SSHRevoke)), SSHRevoke: make([]string, len(a.SSHRevoke)),
@ -83,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences {
ret.Sign[i] = s ret.Sign[i] = s
} }
} }
for i, s := range a.Renew {
if u, err := url.Parse(s); err == nil {
ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
} else {
ret.Renew[i] = s
}
}
for i, s := range a.Revoke { for i, s := range a.Revoke {
if u, err := url.Parse(s); err == nil { if u, err := url.Parse(s); err == nil {
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
@ -210,6 +219,12 @@ type Config struct {
// GetIdentityFunc is a function that returns an identity that will be // GetIdentityFunc is a function that returns an identity that will be
// used by the provisioner to populate certificate attributes. // used by the provisioner to populate certificate attributes.
GetIdentityFunc GetIdentityFunc GetIdentityFunc GetIdentityFunc
// AuthorizeRenewFunc is a function that returns nil if a given X.509
// certificate can be renewed.
AuthorizeRenewFunc AuthorizeRenewFunc
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
// certificate can be renewed.
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
} }
type provisioner struct { type provisioner struct {
@ -278,32 +293,6 @@ func (l *List) UnmarshalJSON(data []byte) error {
return nil return nil
} }
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}
type base struct{} type base struct{}
// AuthorizeSign returns an unimplemented error. Provisioners should overwrite // AuthorizeSign returns an unimplemented error. Provisioners should overwrite
@ -348,66 +337,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented") return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
} }
// Identity is the type representing an externally supplied identity that is used
// by provisioners to populate certificate fields.
type Identity struct {
Usernames []string `json:"usernames"`
Permissions `json:"permissions"`
}
// Permissions defines extra extensions and critical options to grant to an SSH certificate. // Permissions defines extra extensions and critical options to grant to an SSH certificate.
type Permissions struct { type Permissions struct {
Extensions map[string]string `json:"extensions"` Extensions map[string]string `json:"extensions"`
CriticalOptions map[string]string `json:"criticalOptions"` CriticalOptions map[string]string `json:"criticalOptions"`
} }
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might
// contain empty strings.
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
// OIDC principals would be:
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
// 2. Sanitized local.
// 3. Raw local (if different).
// 4. Email address.
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
usernames := []string{name}
if i := strings.LastIndex(email, "@"); i >= 0 {
usernames = append(usernames, email[:i])
}
usernames = append(usernames, email)
return &Identity{
Usernames: SanitizeStringSlices(usernames),
}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// SanitizeStringSlices removes duplicated an empty strings.
func SanitizeStringSlices(original []string) []string {
output := []string{}
seen := make(map[string]struct{})
for _, entry := range original {
if entry == "" {
continue
}
if _, value := seen[entry]; !value {
seen[entry] = struct{}{}
output = append(output, entry)
}
}
return output
}
// MockProvisioner for testing // MockProvisioner for testing
type MockProvisioner struct { type MockProvisioner struct {
Mret1, Mret2, Mret3 interface{} Mret1, Mret2, Mret3 interface{}

View file

@ -5,33 +5,36 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/policy"
"github.com/smallstep/certificates/authority/policy"
) )
// SCEP is the SCEP provisioner type, an entity that can authorize the // SCEP is the SCEP provisioner type, an entity that can authorize the
// SCEP provisioning flow // SCEP provisioning flow
type SCEP struct { type SCEP struct {
*base *base
ID string `json:"-"` ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
ForceCN bool `json:"forceCN,omitempty"` ForceCN bool `json:"forceCN,omitempty"`
ChallengePassword string `json:"challenge,omitempty"` ChallengePassword string `json:"challenge,omitempty"`
Capabilities []string `json:"capabilities,omitempty"` Capabilities []string `json:"capabilities,omitempty"`
// IncludeRoot makes the provisioner return the CA root in addition to the // IncludeRoot makes the provisioner return the CA root in addition to the
// intermediate in the GetCACerts response // intermediate in the GetCACerts response
IncludeRoot bool `json:"includeRoot,omitempty"` IncludeRoot bool `json:"includeRoot,omitempty"`
// MinimumPublicKeyLength is the minimum length for public keys in CSRs // MinimumPublicKeyLength is the minimum length for public keys in CSRs
MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"`
// Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7
// at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63
// Defaults to 0, being DES-CBC // Defaults to 0, being DES-CBC
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
claimer *Claimer ctl *Controller
x509Policy policy.X509NamePolicyEngine x509Policy policy.X509Policy
secretChallengePassword string secretChallengePassword string
encryptionAlgorithm int encryptionAlgorithm int
} }
@ -78,7 +81,7 @@ func (s *SCEP) GetOptions() *Options {
// DefaultTLSCertDuration returns the default TLS cert duration enforced by // DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner. // the provisioner.
func (s *SCEP) DefaultTLSCertDuration() time.Duration { func (s *SCEP) DefaultTLSCertDuration() time.Duration {
return s.claimer.DefaultTLSCertDuration() return s.ctl.Claimer.DefaultTLSCertDuration()
} }
// Init initializes and validates the fields of a SCEP type. // Init initializes and validates the fields of a SCEP type.
@ -90,11 +93,6 @@ func (s *SCEP) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
} }
// Update claims with global ones
if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil {
return err
}
// Mask the actual challenge value, so it won't be marshaled // Mask the actual challenge value, so it won't be marshaled
s.secretChallengePassword = s.ChallengePassword s.secretChallengePassword = s.ChallengePassword
s.ChallengePassword = "*** redacted ***" s.ChallengePassword = "*** redacted ***"
@ -116,11 +114,12 @@ func (s *SCEP) Init(config Config) (err error) {
// TODO: add other, SCEP specific, options? // TODO: add other, SCEP specific, options?
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if s.x509Policy, err = newX509PolicyEngine(s.Options.GetX509Options()); err != nil { if s.x509Policy, err = policy.NewX509PolicyEngine(s.Options.GetX509Options()); err != nil {
return err return err
} }
return err s.ctl, err = NewController(s, s.Claims, config)
return
} }
// AuthorizeSign does not do any verification, because all verification is handled // AuthorizeSign does not do any verification, because all verification is handled
@ -131,10 +130,10 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeSCEP, s.Name, ""), newProvisionerExtensionOption(TypeSCEP, s.Name, ""),
newForceCNOption(s.ForceCN), newForceCNOption(s.ForceCN),
profileDefaultDuration(s.claimer.DefaultTLSCertDuration()), profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength),
newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()), newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(s.x509Policy), newX509NamePolicyValidator(s.x509Policy),
}, nil }, nil
} }

View file

@ -6,7 +6,6 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1"
"encoding/json" "encoding/json"
"net" "net"
"net/http" "net/http"
@ -14,10 +13,11 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// DefaultCertValidity is the default validity for a certificate if none is specified. // DefaultCertValidity is the default validity for a certificate if none is specified.
@ -407,18 +407,17 @@ func (v *validityValidator) Valid(cert *x509.Certificate, o SignOptions) error {
// x509NamePolicyValidator validates that the certificate (to be signed) // x509NamePolicyValidator validates that the certificate (to be signed)
// contains only allowed SANs. // contains only allowed SANs.
type x509NamePolicyValidator struct { type x509NamePolicyValidator struct {
policyEngine x509PolicyEngine policyEngine policy.X509Policy
} }
// newX509NamePolicyValidator return a new SANs allow/deny validator. // newX509NamePolicyValidator return a new SANs allow/deny validator.
func newX509NamePolicyValidator(engine x509PolicyEngine) *x509NamePolicyValidator { func newX509NamePolicyValidator(engine policy.X509Policy) *x509NamePolicyValidator {
return &x509NamePolicyValidator{ return &x509NamePolicyValidator{
policyEngine: engine, policyEngine: engine,
} }
} }
// Valid validates validates that the certificate (to be signed) // Valid validates that the certificate (to be signed) contains only allowed SANs.
// contains only allowed SANs.
func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error { func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) error {
if v.policyEngine == nil { if v.policyEngine == nil {
return nil return nil
@ -427,17 +426,17 @@ func (v *x509NamePolicyValidator) Valid(cert *x509.Certificate, _ SignOptions) e
return err return err
} }
var ( // var (
stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64} // stepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) // stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
) // )
type stepProvisionerASN1 struct { // type stepProvisionerASN1 struct {
Type int // Type int
Name []byte // Name []byte
CredentialID []byte // CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"` // KeyValuePairs []string `asn1:"optional,omitempty"`
} // }
type forceCNOption struct { type forceCNOption struct {
ForceCN bool ForceCN bool
@ -464,23 +463,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
} }
type provisionerExtensionOption struct { type provisionerExtensionOption struct {
Type int Extension
Name string
CredentialID string
KeyValuePairs []string
} }
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption { func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
return &provisionerExtensionOption{ return &provisionerExtensionOption{
Type: int(typ), Extension: Extension{
Name: name, Type: typ,
CredentialID: credentialID, Name: name,
KeyValuePairs: keyValuePairs, CredentialID: credentialID,
KeyValuePairs: keyValuePairs,
},
} }
} }
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) ext, err := o.ToExtension()
if err != nil { if err != nil {
return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
} }
@ -494,20 +492,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption
cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...) cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...)
return nil return nil
} }
func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) {
b, err := asn1.Marshal(stepProvisionerASN1{
Type: typ,
Name: []byte(name),
CredentialID: []byte(credentialID),
KeyValuePairs: keyValuePairs,
})
if err != nil {
return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension")
}
return pkix.Extension{
Id: stepOIDProvisioner,
Critical: false,
Value: b,
}, nil
}

View file

@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) {
valid: func(cert *x509.Certificate) { valid: func(cert *x509.Certificate) {
if assert.Len(t, 1, cert.ExtraExtensions) { if assert.Len(t, 1, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0] ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner) assert.Equals(t, ext.Id, StepOIDProvisioner)
} }
}, },
} }
}, },
"ok/prepend": func() test { "ok/prepend": func() test {
return test{ return test{
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
valid: func(cert *x509.Certificate) { valid: func(cert *x509.Certificate) {
if assert.Len(t, 3, cert.ExtraExtensions) { if assert.Len(t, 3, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0] ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner) assert.Equals(t, ext.Id, StepOIDProvisioner)
assert.False(t, ext.Critical) assert.False(t, ext.Critical)
} }
}, },

View file

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -448,20 +449,19 @@ func (v sshDefaultPublicKeyValidator) Valid(cert *ssh.Certificate, o SignSSHOpti
// sshNamePolicyValidator validates that the certificate (to be signed) // sshNamePolicyValidator validates that the certificate (to be signed)
// contains only allowed principals. // contains only allowed principals.
type sshNamePolicyValidator struct { type sshNamePolicyValidator struct {
hostPolicyEngine *hostPolicyEngine hostPolicyEngine policy.HostPolicy
userPolicyEngine *userPolicyEngine userPolicyEngine policy.UserPolicy
} }
// newSSHNamePolicyValidator return a new SSH allow/deny validator. // newSSHNamePolicyValidator return a new SSH allow/deny validator.
func newSSHNamePolicyValidator(host *hostPolicyEngine, user *userPolicyEngine) *sshNamePolicyValidator { func newSSHNamePolicyValidator(host policy.HostPolicy, user policy.UserPolicy) *sshNamePolicyValidator {
return &sshNamePolicyValidator{ return &sshNamePolicyValidator{
hostPolicyEngine: host, hostPolicyEngine: host,
userPolicyEngine: user, userPolicyEngine: user,
} }
} }
// Valid validates validates that the certificate (to be signed) // Valid validates that the certificate (to be signed) contains only allowed principals.
// contains only allowed principals.
func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error { func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions) error {
if v.hostPolicyEngine == nil && v.userPolicyEngine == nil { if v.hostPolicyEngine == nil && v.userPolicyEngine == nil {
// no policy configured at all; allow anything // no policy configured at all; allow anything
@ -473,29 +473,25 @@ func (v *sshNamePolicyValidator) Valid(cert *ssh.Certificate, _ SignSSHOptions)
// the same for host certs: if only a user policy engine is configured, host // the same for host certs: if only a user policy engine is configured, host
// certs are denied. When both policy engines are configured, the type of // certs are denied. When both policy engines are configured, the type of
// cert determines which policy engine is used. // cert determines which policy engine is used.
policyType, ok := certTypeToPolicyEngineType[cert.CertType] switch cert.CertType {
if !ok { case ssh.HostCert:
return fmt.Errorf("unexpected SSH cert type %d", cert.CertType)
}
switch policyType {
case hostPolicyEngineType:
// when no host policy engine is configured, but a user policy engine is // when no host policy engine is configured, but a user policy engine is
// configured, we don't allow the host certificate. // configured, the host certificate is denied.
if v.hostPolicyEngine == nil && v.userPolicyEngine != nil { if v.hostPolicyEngine == nil && v.userPolicyEngine != nil {
return errors.New("SSH host certificate not authorized") // TODO: include principals in message? return errors.New("SSH host certificate not authorized")
} }
_, err := v.hostPolicyEngine.ArePrincipalsAllowed(cert) _, err := v.hostPolicyEngine.ArePrincipalsAllowed(cert)
return err return err
case userPolicyEngineType: case ssh.UserCert:
// when no user policy engine is configured, but a host policy engine is // when no user policy engine is configured, but a host policy engine is
// configured, we don't allow the user certificate. // configured, the user certificate is denied.
if v.userPolicyEngine == nil && v.hostPolicyEngine != nil { if v.userPolicyEngine == nil && v.hostPolicyEngine != nil {
return errors.New("SSH user certificate not authorized") // TODO: include principals in message? return errors.New("SSH user certificate not authorized")
} }
_, err := v.userPolicyEngine.ArePrincipalsAllowed(cert) _, err := v.userPolicyEngine.ArePrincipalsAllowed(cert)
return err return err
default: default:
return fmt.Errorf("unexpected policy engine type %q", policyType) // satisfy return; shouldn't happen return fmt.Errorf("unexpected SSH certificate type %d", cert.CertType) // satisfy return; shouldn't happen
} }
} }

View file

@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
func Test_sshCertValidityValidator(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
v := sshCertValidityValidator{p.claimer} v := sshCertValidityValidator{p.ctl.Claimer}
n := now() n := now()
tests := []struct { tests := []struct {
name string name string
@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) {
tests := map[string]func() test{ tests := map[string]func() test{
"fail/type-not-set": func() test { "fail/type-not-set": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/type-not-recognized": func() test { "fail/type-not-recognized": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 4, CertType: 4,
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/requested-validAfter-after-limit": func() test { "fail/requested-validAfter-after-limit": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), ValidAfter: uint64(n.Add(2 * time.Hour).Unix()),
@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/requested-validBefore-after-limit": func() test { "fail/requested-validBefore-after-limit": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/no-limit": func() test { "ok/no-limit": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
}, },
@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/defaults": func() test { "ok/defaults": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
}, },
@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/valid-requested-validBefore": func() test { "ok/valid-requested-validBefore": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,
@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/empty-requested-validBefore-limit-after-default": func() test { "ok/empty-requested-validBefore-limit-after-default": func() test {
va := uint64(n.Unix()) va := uint64(n.Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,
@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/empty-requested-validBefore-limit-before-default": func() test { "ok/empty-requested-validBefore-limit-before-default": func() test {
va := uint64(n.Unix()) va := uint64(n.Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,

View file

@ -6,6 +6,8 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"github.com/smallstep/certificates/authority/policy"
) )
// SSHCertificateOptions is an interface that returns a list of options passed when // SSHCertificateOptions is an interface that returns a list of options passed when
@ -35,32 +37,58 @@ type SSHOptions struct {
TemplateData json.RawMessage `json:"templateData,omitempty"` TemplateData json.RawMessage `json:"templateData,omitempty"`
// User contains SSH user certificate options. // User contains SSH user certificate options.
User *SSHUserCertificateOptions `json:"user,omitempty"` User *policy.SSHUserCertificateOptions
// Host contains SSH host certificate options. // Host contains SSH host certificate options.
Host *SSHHostCertificateOptions `json:"host,omitempty"` Host *policy.SSHHostCertificateOptions
} }
// SSHUserCertificateOptions is a collection of SSH user certificate options. // GetAllowedUserNameOptions returns the SSHNameOptions that are
type SSHUserCertificateOptions struct { // allowed when SSH User certificates are requested.
// AllowedNames contains the names the provisioner is authorized to sign func (o *SSHOptions) GetAllowedUserNameOptions() *policy.SSHNameOptions {
AllowedNames *SSHNameOptions `json:"allow,omitempty"` if o == nil {
return nil
// DeniedNames contains the names the provisioner is not authorized to sign }
DeniedNames *SSHNameOptions `json:"deny,omitempty"` if o.User == nil {
return nil
}
return o.User.AllowedNames
} }
// SSHHostCertificateOptions is a collection of SSH host certificate options. // GetDeniedUserNameOptions returns the SSHNameOptions that are
// It's an alias of SSHUserCertificateOptions, as the options are the same // denied when SSH user certificates are requested.
// for both types of certificates. func (o *SSHOptions) GetDeniedUserNameOptions() *policy.SSHNameOptions {
type SSHHostCertificateOptions SSHUserCertificateOptions if o == nil {
return nil
}
if o.User == nil {
return nil
}
return o.User.DeniedNames
}
// SSHNameOptions models the SSH name policy configuration. // GetAllowedHostNameOptions returns the SSHNameOptions that are
type SSHNameOptions struct { // allowed when SSH host certificates are requested.
DNSDomains []string `json:"dns,omitempty"` func (o *SSHOptions) GetAllowedHostNameOptions() *policy.SSHNameOptions {
IPRanges []string `json:"ip,omitempty"` if o == nil {
EmailAddresses []string `json:"email,omitempty"` return nil
Principals []string `json:"principal,omitempty"` }
if o.Host == nil {
return nil
}
return o.Host.AllowedNames
}
// GetDeniedHostNameOptions returns the SSHNameOptions that are
// denied when SSH host certificates are requested.
func (o *SSHOptions) GetDeniedHostNameOptions() *policy.SSHNameOptions {
if o == nil {
return nil
}
if o.Host == nil {
return nil
}
return o.Host.DeniedNames
} }
// HasTemplate returns true if a template is defined in the provisioner options. // HasTemplate returns true if a template is defined in the provisioner options.
@ -68,32 +96,6 @@ func (o *SSHOptions) HasTemplate() bool {
return o != nil && (o.Template != "" || o.TemplateFile != "") return o != nil && (o.Template != "" || o.TemplateFile != "")
} }
// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the
// names that a provisioner is authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the
// names that a provisioner is NOT authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// HasNames checks if the SSHNameOptions has one or more
// names configured.
func (o *SSHNameOptions) HasNames() bool {
return len(o.DNSDomains) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.Principals) > 0
}
// TemplateSSHOptions generates a SSHCertificateOptions with the template and // TemplateSSHOptions generates a SSHCertificateOptions with the template and
// data defined in the ProvisionerOptions, the provisioner generated data, and // data defined in the ProvisionerOptions, the provisioner generated data, and
// the user data provided in the request. If no template has been provided, // the user data provided in the request. If no template has been provided,

View file

@ -29,8 +29,7 @@ type SSHPOP struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences
sshPubKeys *SSHKeys sshPubKeys *SSHKeys
} }
@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) {
} }
// Init initializes and validates the fields of a SSHPOP type. // Init initializes and validates the fields of a SSHPOP type.
func (p *SSHPOP) Init(config Config) error { func (p *SSHPOP) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -93,17 +92,15 @@ func (p *SSHPOP) Init(config Config) error {
return errors.New("provisioner public SSH validation keys cannot be empty") return errors.New("provisioner public SSH validation keys cannot be empty")
} }
// Update claims with global ones
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// TODO(hs): initialize the policy engine and add it as an SSH cert validator // TODO(hs): initialize the policy engine and add it as an SSH cert validator
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.sshPubKeys = config.SSHKeys p.sshPubKeys = config.SSHKeys
return nil
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
// Update claims with global ones
p.ctl, err = NewController(p, p.Claims, config)
return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -111,7 +108,7 @@ func (p *SSHPOP) Init(config Config) error {
// e.g. a Sign request will auth/validate different fields than a Revoke request. // e.g. a Sign request will auth/validate different fields than a Revoke request.
// //
// Checking for certificate revocation has been moved to the authority package. // Checking for certificate revocation has been moved to the authority package.
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) {
sshCert, jwt, err := ExtractSSHPOPCert(token) sshCert, jwt, err := ExtractSSHPOPCert(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, return nil, errs.Wrap(http.StatusUnauthorized, err,
@ -119,13 +116,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
} }
// Check validity period of the certificate. // Check validity period of the certificate.
n := time.Now() //
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { // Controller.AuthorizeSSHRenew will validate this on the renewal flow.
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") if checkValidity {
} unixNow := time.Now().Unix()
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
}
} }
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok { if !ok {
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey") return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
@ -188,7 +190,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// AuthorizeSSHRevoke validates the authorization token and extracts/validates // AuthorizeSSHRevoke validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
} }
@ -201,22 +203,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
// AuthorizeSSHRenew validates the authorization token and extracts/validates // AuthorizeSSHRenew validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRenew) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert)
return claims.sshCert, nil
} }
// AuthorizeSSHRekey validates the authorization token and extracts/validates // AuthorizeSSHRekey validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRekey) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true)
if err != nil { if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
} }
@ -227,11 +227,10 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
}, nil }, nil
} }
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate

View file

@ -38,6 +38,7 @@ func TestSSHPOP_Getters(t *testing.T) {
} }
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
now := time.Now()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -46,6 +47,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now.Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
}
if err := cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -207,7 +214,7 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
@ -455,7 +462,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator: case *sshCertDefaultValidator:
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
} }

View file

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy
MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc
pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49
BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY
wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g=
-----END CERTIFICATE-----

View file

@ -0,0 +1,22 @@
-----BEGIN CERTIFICATE-----
MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx
MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0
4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz
bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf
SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB
bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w==
-----END CERTIFICATE-----

View file

@ -24,20 +24,22 @@ import (
) )
var ( var (
defaultDisableRenewal = false defaultDisableRenewal = false
defaultEnableSSHCA = true defaultAllowRenewAfterExpiry = false
globalProvisionerClaims = Claims{ defaultEnableSSHCA = true
MinTLSDur: &Duration{5 * time.Minute}, globalProvisionerClaims = Claims{
MaxTLSDur: &Duration{24 * time.Hour}, MinTLSDur: &Duration{5 * time.Minute},
DefaultTLSDur: &Duration{24 * time.Hour}, MaxTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal, DefaultTLSDur: &Duration{24 * time.Hour},
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA, EnableSSHCA: &defaultEnableSSHCA,
DisableRenewal: &defaultDisableRenewal,
AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry,
} }
testAudiences = Audiences{ testAudiences = Audiences{
Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},
@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil { p := &JWK{
return nil, err
}
return &JWK{
Name: name, Name: name,
Type: "JWK", Type: "JWK",
Key: &public, Key: &public,
EncryptedKey: encrypted, EncryptedKey: encrypted,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, }
claimer: claimer, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
pubKeys := []interface{}{fooPub, barPub} pubKeys := []interface{}{fooPub, barPub}
if inputPubKey != nil { if inputPubKey != nil {
pubKeys = append(pubKeys, inputPubKey) pubKeys = append(pubKeys, inputPubKey)
} }
return &K8sSA{ p := &K8sSA{
Name: K8sSAName, Name: K8sSAName,
Type: "K8sSA", Type: "K8sSA",
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, pubKeys: pubKeys,
claimer: claimer, }
pubKeys: pubKeys, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateSSHPOP() (*SSHPOP, error) { func generateSSHPOP() (*SSHPOP, error) {
@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub") userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) {
return nil, err return nil, err
} }
return &SSHPOP{ p := &SSHPOP{
Name: name, Name: name,
Type: "SSHPOP", Type: "SSHPOP",
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
sshPubKeys: &SSHKeys{ sshPubKeys: &SSHKeys{
UserKeys: []ssh.PublicKey{userKey}, UserKeys: []ssh.PublicKey{userKey},
HostKeys: []ssh.PublicKey{hostKey}, HostKeys: []ssh.PublicKey{hostKey},
}, },
}, nil }
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateX5C(root []byte) (*X5C, error) { func generateX5C(root []byte) (*X5C, error) {
@ -283,11 +279,6 @@ M46l92gdOozT
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
rootPool := x509.NewCertPool() rootPool := x509.NewCertPool()
var ( var (
@ -305,15 +296,17 @@ M46l92gdOozT
} }
rootPool.AddCert(cert) rootPool.AddCert(cert)
} }
return &X5C{ p := &X5C{
Name: name, Name: name,
Type: "X5C", Type: "X5C",
Roots: root, Roots: root,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, rootPool: rootPool,
claimer: claimer, }
rootPool: rootPool, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateOIDC() (*OIDC, error) { func generateOIDC() (*OIDC, error) {
@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims) p := &OIDC{
if err != nil {
return nil, err
}
return &OIDC{
Name: name, Name: name,
Type: "OIDC", Type: "OIDC",
ClientID: clientID, ClientID: clientID,
@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
claimer: claimer, }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateGCP() (*GCP, error) { func generateGCP() (*GCP, error) {
@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims) p := &GCP{
if err != nil {
return nil, err
}
return &GCP{
Type: "GCP", Type: "GCP",
Name: name, Name: name,
ServiceAccounts: []string{serviceAccount}, ServiceAccounts: []string{serviceAccount},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
claimer: claimer,
config: newGCPConfig(), config: newGCPConfig(),
keyStore: &keyStore{ keyStore: &keyStore{
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
audiences: testAudiences.WithFragment("gcp/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("gcp/" + name),
})
return p, err
} }
func generateAWS() (*AWS, error) { func generateAWS() (*AWS, error) {
@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate)) block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" { if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate") return nil, errors.New("error decoding AWS certificate")
@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate") return nil, errors.Wrap(err, "error parsing AWS certificate")
} }
return &AWS{ p := &AWS{
Type: "AWS", Type: "AWS",
Name: name, Name: name,
Accounts: []string{accountID}, Accounts: []string{accountID},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v2", "v1"}, IMDSVersions: []string{"v2", "v1"},
claimer: claimer,
config: &awsConfig{ config: &awsConfig{
identityURL: awsIdentityURL, identityURL: awsIdentityURL,
signatureURL: awsSignatureURL, signatureURL: awsSignatureURL,
@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) {
certificates: []*x509.Certificate{cert}, certificates: []*x509.Certificate{cert},
signatureAlgorithm: awsSignatureAlgorithm, signatureAlgorithm: awsSignatureAlgorithm,
}, },
audiences: testAudiences.WithFragment("aws/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
return p, err
} }
func generateAWSWithServer() (*AWS, *httptest.Server, error) { func generateAWSWithServer() (*AWS, *httptest.Server, error) {
@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate)) block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" { if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate") return nil, errors.New("error decoding AWS certificate")
@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate") return nil, errors.Wrap(err, "error parsing AWS certificate")
} }
return &AWS{ p := &AWS{
Type: "AWS", Type: "AWS",
Name: name, Name: name,
Accounts: []string{accountID}, Accounts: []string{accountID},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v1"}, IMDSVersions: []string{"v1"},
claimer: claimer,
config: &awsConfig{ config: &awsConfig{
identityURL: awsIdentityURL, identityURL: awsIdentityURL,
signatureURL: awsSignatureURL, signatureURL: awsSignatureURL,
@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) {
certificates: []*x509.Certificate{cert}, certificates: []*x509.Certificate{cert},
signatureAlgorithm: awsSignatureAlgorithm, signatureAlgorithm: awsSignatureAlgorithm,
}, },
audiences: testAudiences.WithFragment("aws/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
return p, err
} }
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey() jwk, err := generateJSONWebKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Azure{ p := &Azure{
Type: "Azure", Type: "Azure",
Name: name, Name: name,
TenantID: tenantID, TenantID: tenantID,
Audience: azureDefaultAudience, Audience: azureDefaultAudience,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
claimer: claimer,
config: newAzureConfig(tenantID), config: newAzureConfig(tenantID),
oidcConfig: openIDConfiguration{ oidcConfig: openIDConfiguration{
Issuer: "https://sts.windows.net/" + tenantID + "/", Issuer: "https://sts.windows.net/" + tenantID + "/",
@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
}, nil }
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateAzureWithServer() (*Azure, *httptest.Server, error) { func generateAzureWithServer() (*Azure, *httptest.Server, error) {
@ -671,7 +656,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
w.Header().Add("Cache-Control", "max-age=5") w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(az.keyStore.keySet)) writeJSON(w, getPublic(az.keyStore.keySet))
case "/metadata/identity/oauth2/token": case "/metadata/identity/oauth2/token":
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0]) tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0])
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} else { } else {
@ -1009,7 +994,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner( sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
@ -1017,6 +1002,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
if err != nil { if err != nil {
return "", err return "", err
} }
var xmsMirID string
if resourceType == "vm" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName)
} else if resourceType == "uai" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName)
}
claims := azurePayload{ claims := azurePayload{
Claims: jose.Claims{ Claims: jose.Claims{
@ -1034,7 +1025,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
ObjectID: "the-oid", ObjectID: "the-oid",
TenantID: tenantID, TenantID: tenantID,
Version: "the-version", Version: "the-version",
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine), XMSMirID: xmsMirID,
} }
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }

View file

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
@ -32,12 +33,11 @@ type X5C struct {
Roots []byte `json:"roots"` Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences
rootPool *x509.CertPool rootPool *x509.CertPool
x509Policy x509PolicyEngine x509Policy policy.X509Policy
sshHostPolicy *hostPolicyEngine sshHostPolicy policy.HostPolicy
sshUserPolicy *userPolicyEngine sshUserPolicy policy.UserPolicy
} }
// GetID returns the provisioner unique identifier. The name and credential id // GetID returns the provisioner unique identifier. The name and credential id
@ -89,7 +89,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) {
} }
// Init initializes and validates the fields of a X5C type. // Init initializes and validates the fields of a X5C type.
func (p *X5C) Init(config Config) error { func (p *X5C) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -104,6 +104,7 @@ func (p *X5C) Init(config Config) error {
var ( var (
block *pem.Block block *pem.Block
rest = p.Roots rest = p.Roots
count int
) )
for rest != nil { for rest != nil {
block, rest = pem.Decode(rest) block, rest = pem.Decode(rest)
@ -114,37 +115,33 @@ func (p *X5C) Init(config Config) error {
if err != nil { if err != nil {
return errors.Wrap(err, "error parsing x509 certificate from PEM block") return errors.Wrap(err, "error parsing x509 certificate from PEM block")
} }
count++
p.rootPool.AddCert(cert) p.rootPool.AddCert(cert)
} }
// Verify that at least one root was found. // Verify that at least one root was found.
if len(p.rootPool.Subjects()) == 0 { if count == 0 {
return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName())
} }
// Update claims with global ones
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = newX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for user certificates // Initialize the SSH allow/deny policy engine for user certificates
if p.sshUserPolicy, err = newSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshUserPolicy, err = policy.NewSSHUserPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
// Initialize the SSH allow/deny policy engine for host certificates // Initialize the SSH allow/deny policy engine for host certificates
if p.sshHostPolicy, err = newSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil { if p.sshHostPolicy, err = policy.NewSSHHostPolicyEngine(p.Options.GetSSHOptions()); err != nil {
return err return err
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -207,13 +204,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
} }
@ -245,32 +242,31 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeX5C, p.Name, ""), newProvisionerExtensionOption(TypeX5C, p.Name, ""),
profileLimitDuration{p.claimer.DefaultTLSCertDuration(), profileLimitDuration{
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter}, p.ctl.Claimer.DefaultTLSCertDuration(),
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter,
},
// validators // validators
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
defaultSANsValidator(claims.SANs), defaultSANsValidator(claims.SANs),
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
}, nil }, nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
} }
@ -333,11 +329,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
return append(signOptions, return append(signOptions,
// Checks the validity bounds, and set the validity if has not been set. // Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter}, &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter},
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -2,6 +2,7 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509"
"net/http" "net/http"
"testing" "testing"
"time" "time"
@ -69,7 +70,7 @@ func TestX5C_Init(t *testing.T) {
}, },
"fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences}, p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")},
err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"),
} }
}, },
@ -117,6 +118,8 @@ M46l92gdOozT
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: p, p: p,
extraValid: func(p *X5C) error { extraValid: func(p *X5C) error {
// nolint:staticcheck // We don't have a different way to
// check the number of certificates in the pool.
numCerts := len(p.rootPool.Subjects()) numCerts := len(p.rootPool.Subjects())
if numCerts != 2 { if numCerts != 2 {
return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts) return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
@ -141,7 +144,7 @@ M46l92gdOozT
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID())) assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID()))
if tc.extraValid != nil { if tc.extraValid != nil {
assert.Nil(t, tc.extraValid(tc.p)) assert.Nil(t, tc.extraValid(tc.p))
} }
@ -468,13 +471,13 @@ func TestX5C_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeX5C)) assert.Equals(t, v.Type, TypeX5C)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileLimitDuration: case profileLimitDuration:
assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration())
claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign) claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter)
case commonNameValidator: case commonNameValidator:
@ -483,8 +486,8 @@ func TestX5C_AuthorizeSign(t *testing.T) {
case defaultSANsValidator: case defaultSANsValidator:
assert.Equals(t, []string(v), tc.sans) assert.Equals(t, []string(v), tc.sans)
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
default: default:
@ -552,6 +555,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
} }
func TestX5C_AuthorizeRenew(t *testing.T) { func TestX5C_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *X5C p *X5C
code int code int
@ -564,12 +568,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()), err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -583,7 +587,10 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil { if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -619,7 +626,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
enable := false enable := false
p.Claims = &Claims{EnableSSHCA: &enable} p.Claims = &Claims{EnableSSHCA: &enable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
@ -775,10 +782,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
case sshCertDefaultsModifier: case sshCertDefaultsModifier:
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert}) assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
case *sshLimitDuration: case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshNamePolicyValidator: case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine) assert.Equals(t, nil, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine) assert.Equals(t, nil, v.hostPolicyEngine)

View file

@ -12,6 +12,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/step" "go.step.sm/cli-utils/step"
@ -108,7 +109,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.
UserKeys: sshKeys.UserKeys, UserKeys: sshKeys.UserKeys,
HostKeys: sshKeys.HostKeys, HostKeys: sshKeys.HostKeys,
}, },
GetIdentityFunc: a.getIdentityFunc, GetIdentityFunc: a.getIdentityFunc,
AuthorizeRenewFunc: a.authorizeRenewFunc,
AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
}, nil }, nil
} }
@ -395,6 +398,58 @@ func optionsToCertificates(p *linkedca.Provisioner) *provisioner.Options {
ops.SSH.Template = string(p.SshTemplate.Template) ops.SSH.Template = string(p.SshTemplate.Template)
ops.SSH.TemplateData = p.SshTemplate.Data ops.SSH.TemplateData = p.SshTemplate.Data
} }
if p.Policy != nil {
if p.Policy.X509 != nil {
if p.Policy.X509.Allow != nil {
ops.X509.AllowedNames = &policy.X509NameOptions{
DNSDomains: p.Policy.X509.Allow.Dns,
IPRanges: p.Policy.X509.Allow.Ips,
EmailAddresses: p.Policy.X509.Allow.Emails,
URIDomains: p.Policy.X509.Allow.Uris,
}
}
if p.Policy.X509.Deny != nil {
ops.X509.DeniedNames = &policy.X509NameOptions{
DNSDomains: p.Policy.X509.Deny.Dns,
IPRanges: p.Policy.X509.Deny.Ips,
EmailAddresses: p.Policy.X509.Deny.Emails,
URIDomains: p.Policy.X509.Deny.Uris,
}
}
}
if p.Policy.Ssh != nil {
if p.Policy.Ssh.Host != nil {
ops.SSH.Host = &policy.SSHHostCertificateOptions{}
if p.Policy.Ssh.Host.Allow != nil {
ops.SSH.Host.AllowedNames = &policy.SSHNameOptions{
DNSDomains: p.Policy.Ssh.Host.Allow.Dns,
IPRanges: p.Policy.Ssh.Host.Allow.Ips,
}
}
if p.Policy.Ssh.Host.Deny != nil {
ops.SSH.Host.DeniedNames = &policy.SSHNameOptions{
DNSDomains: p.Policy.Ssh.Host.Deny.Dns,
IPRanges: p.Policy.Ssh.Host.Deny.Ips,
}
}
}
if p.Policy.Ssh.User != nil {
ops.SSH.User = &policy.SSHUserCertificateOptions{}
if p.Policy.Ssh.User.Allow != nil {
ops.SSH.User.AllowedNames = &policy.SSHNameOptions{
EmailAddresses: p.Policy.Ssh.User.Allow.Emails,
Principals: p.Policy.Ssh.User.Allow.Principals,
}
}
if p.Policy.Ssh.User.Deny != nil {
ops.SSH.User.DeniedNames = &policy.SSHNameOptions{
EmailAddresses: p.Policy.Ssh.User.Deny.Emails,
Principals: p.Policy.Ssh.User.Deny.Principals,
}
}
}
}
}
return ops return ops
} }
@ -435,7 +490,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
} }
pc := &provisioner.Claims{ pc := &provisioner.Claims{
DisableRenewal: &c.DisableRenewal, DisableRenewal: &c.DisableRenewal,
AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry,
} }
var err error var err error
@ -473,12 +529,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
} }
disableRenewal := config.DefaultDisableRenewal disableRenewal := config.DefaultDisableRenewal
allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry
if c.DisableRenewal != nil { if c.DisableRenewal != nil {
disableRenewal = *c.DisableRenewal disableRenewal = *c.DisableRenewal
} }
if c.AllowRenewAfterExpiry != nil {
allowRenewAfterExpiry = *c.AllowRenewAfterExpiry
}
lc := &linkedca.Claims{ lc := &linkedca.Claims{
DisableRenewal: disableRenewal, DisableRenewal: disableRenewal,
AllowRenewAfterExpiry: allowRenewAfterExpiry,
} }
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {

View file

@ -9,6 +9,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
@ -241,6 +242,45 @@ func (a *Authority) SignSSH(ctx context.Context, key ssh.PublicKey, opts provisi
return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType) return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType)
} }
switch certTpl.CertType {
case ssh.UserCert:
// when no user policy engine is configured, but a host policy engine is
// configured, the user certificate is denied.
if a.sshUserPolicy == nil && a.sshHostPolicy != nil {
return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign ssh user certificates"), "authority.SignSSH: error creating ssh user certificate")
}
if a.sshUserPolicy != nil {
allowed, err := a.sshUserPolicy.ArePrincipalsAllowed(certTpl)
if err != nil {
return nil, errs.InternalServerErr(err,
errs.WithMessage("authority.SignSSH: error creating ssh user certificate"),
)
}
if !allowed {
return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign"), "authority.SignSSH: error creating ssh user certificate")
}
}
case ssh.HostCert:
// when no host policy engine is configured, but a user policy engine is
// configured, the host certificate is denied.
if a.sshHostPolicy == nil && a.sshUserPolicy != nil {
return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign ssh host certificates"), "authority.SignSSH: error creating ssh user certificate")
}
if a.sshHostPolicy != nil {
allowed, err := a.sshHostPolicy.ArePrincipalsAllowed(certTpl)
if err != nil {
return nil, errs.InternalServerErr(err,
errs.WithMessage("authority.SignSSH: error creating ssh host certificate"),
)
}
if !allowed {
return nil, errs.ForbiddenErr(errors.New("authority not allowed to sign"), "authority.SignSSH: error creating ssh host certificate")
}
}
default:
return nil, errs.InternalServer("authority.SignSSH: unexpected ssh certificate type: %d", certTpl.CertType)
}
// Sign certificate. // Sign certificate.
cert, err := sshutil.CreateCertificate(certTpl, signer) cert, err := sshutil.CreateCertificate(certTpl, signer)
if err != nil { if err != nil {

View file

@ -191,6 +191,22 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
} }
} }
// Check if authority is allowed to sign the certificate
var allowedToSign bool
if allowedToSign, err = a.isAllowedToSign(leaf); err != nil {
return nil, errs.InternalServerErr(err,
errs.WithKeyVal("csr", csr),
errs.WithKeyVal("signOptions", signOpts),
errs.WithMessage("error creating certificate"),
)
}
if !allowedToSign {
return nil, errs.ApplyOptions(
errs.ForbiddenErr(errors.New("authority not allowed to sign"), "error creating certificate"),
opts...,
)
}
// Sign certificate // Sign certificate
lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate)) lifetime := leaf.NotAfter.Sub(leaf.NotBefore.Add(signOpts.Backdate))
resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{ resp, err := a.x509CAService.CreateCertificate(&casapi.CreateCertificateRequest{
@ -214,6 +230,21 @@ func (a *Authority) Sign(csr *x509.CertificateRequest, signOpts provisioner.Sign
return fullchain, nil return fullchain, nil
} }
// isAllowedToSign checks if the Authority is allowed to sign the X.509 certificate.
// It first checks if the certificate contains an admin subject that exists in the
// collection of admins. The CA is always allowed to sign those. If the cert contains
// different names and a policy is configured, the policy will be executed against
// the cert to see if the CA is allowed to sign it.
func (a *Authority) isAllowedToSign(cert *x509.Certificate) (bool, error) {
// if no policy is configured, the cert is implicitly allowed
if a.x509Policy == nil {
return true, nil
}
return a.x509Policy.AreCertificateNamesAllowed(cert)
}
// Renew creates a new Certificate identical to the old certificate, except // Renew creates a new Certificate identical to the old certificate, except
// with a validity window that begins 'now'. // with a validity window that begins 'now'.
func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) { func (a *Authority) Renew(oldCert *x509.Certificate) ([]*x509.Certificate, error) {

View file

@ -757,7 +757,7 @@ func TestAuthority_Renew(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7) nb1 := now.Add(-time.Minute * 7)
na1 := now na1 := now.Add(time.Hour)
so := &provisioner.SignOptions{ so := &provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nb1), NotBefore: provisioner.NewTimeDuration(nb1),
NotAfter: provisioner.NewTimeDuration(na1), NotAfter: provisioner.NewTimeDuration(na1),
@ -798,7 +798,20 @@ func TestAuthority_Renew(t *testing.T) {
"fail/unauthorized": func() (*renewTest, error) { "fail/unauthorized": func() (*renewTest, error) {
return &renewTest{ return &renewTest{
cert: certNoRenew, cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},
"fail/WithAuthorizeRenewFunc": func() (*renewTest, error) {
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
return errs.Unauthorized("not authorized")
}))
aa.x509CAService = a.x509CAService
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
return &renewTest{
auth: aa,
cert: cert,
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
}, nil }, nil
}, },
@ -820,6 +833,17 @@ func TestAuthority_Renew(t *testing.T) {
cert: cert, cert: cert,
}, nil }, nil
}, },
"ok/WithAuthorizeRenewFunc": func() (*renewTest, error) {
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
return nil
}))
aa.x509CAService = a.x509CAService
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
return &renewTest{
auth: aa,
cert: cert,
}, nil
},
} }
for name, genTestCase := range tests { for name, genTestCase := range tests {
@ -856,7 +880,7 @@ func TestAuthority_Renew(t *testing.T) {
expiry := now.Add(time.Minute * 7) expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(), assert.Equals(t, leaf.Subject.String(),
@ -956,7 +980,7 @@ func TestAuthority_Rekey(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7) nb1 := now.Add(-time.Minute * 7)
na1 := now na1 := now.Add(time.Hour)
so := &provisioner.SignOptions{ so := &provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nb1), NotBefore: provisioner.NewTimeDuration(nb1),
NotAfter: provisioner.NewTimeDuration(na1), NotAfter: provisioner.NewTimeDuration(na1),
@ -998,7 +1022,7 @@ func TestAuthority_Rekey(t *testing.T) {
"fail/unauthorized": func() (*renewTest, error) { "fail/unauthorized": func() (*renewTest, error) {
return &renewTest{ return &renewTest{
cert: certNoRenew, cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
}, nil }, nil
}, },
@ -1063,7 +1087,7 @@ func TestAuthority_Rekey(t *testing.T) {
expiry := now.Add(time.Minute * 7) expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(), assert.Equals(t, leaf.Subject.String(),

View file

@ -408,6 +408,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
server.ServeTLS(listener, "", "") server.ServeTLS(listener, "", "")
}() }()
defer server.Close() defer server.Close()
time.Sleep(1 * time.Second)
// Create bootstrap client // Create bootstrap client
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
@ -419,7 +420,6 @@ func TestBootstrapClientServerRotation(t *testing.T) {
// doTest does a request that requires mTLS // doTest does a request that requires mTLS
doTest := func(client *http.Client) error { doTest := func(client *http.Client) error {
time.Sleep(1 * time.Second)
// test with ca // test with ca
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
if err != nil { if err != nil {

View file

@ -208,7 +208,8 @@ func (ca *CA) Init(cfg *config.Config) (*CA, error) {
adminDB := auth.GetAdminDatabase() adminDB := auth.GetAdminDatabase()
if adminDB != nil { if adminDB != nil {
acmeAdminResponder := adminAPI.NewACMEAdminResponder() acmeAdminResponder := adminAPI.NewACMEAdminResponder()
adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder) policyAdminResponder := adminAPI.NewPolicyAdminResponder(auth, adminDB)
adminHandler := adminAPI.NewHandler(auth, adminDB, acmeDB, acmeAdminResponder, policyAdminResponder)
mux.Route("/admin", func(r chi.Router) { mux.Route("/admin", func(r chi.Router) {
adminHandler.Route(r) adminHandler.Route(r)
}) })
@ -450,9 +451,6 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
tlsConfig.ClientCAs = certPool tlsConfig.ClientCAs = certPool
// Use server's most preferred ciphersuite
tlsConfig.PreferServerCipherSuites = true
return tlsConfig, nil return tlsConfig, nil
} }

View file

@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool {
return false return false
} }
// GetCaURL returns the configured CA url.
func (c *Client) GetCaURL() string {
return c.endpoint.String()
}
// GetRootCAs returns the RootCAs certificate pool from the configured // GetRootCAs returns the RootCAs certificate pool from the configured
// transport. // transport.
func (c *Client) GetRootCAs() *x509.CertPool { func (c *Client) GetRootCAs() *x509.CertPool {
@ -723,6 +728,36 @@ retry:
return &sign, nil return &sign, nil
} }
// RenewWithToken performs the renew request to the CA with the given
// authorization token and returns the api.SignResponse struct. This method is
// generally used to renew an expired certificate.
func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
req, err := http.NewRequest("POST", u.String(), http.NoBody)
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request")
}
req.Header.Add("Authorization", "Bearer "+token)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body)
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u)
}
return &sign, nil
}
// Rekey performs the rekey request to the CA and returns the api.SignResponse // Rekey performs the rekey request to the CA and returns the api.SignResponse
// struct. // struct.
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {

View file

@ -17,13 +17,16 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -354,7 +357,7 @@ func TestClient_Sign(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest) body := new(api.SignRequest)
if err := api.ReadJSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e) api.WriteError(w, e)
@ -426,7 +429,7 @@ func TestClient_Revoke(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest) body := new(api.RevokeRequest)
if err := api.ReadJSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e) api.WriteError(w, e)
@ -529,6 +532,74 @@ func TestClient_Renew(t *testing.T) {
} }
} }
func TestClient_RenewWithToken(t *testing.T) {
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)},
{Certificate: parseCertificate(rootPEM)},
},
}
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" {
api.JSONStatus(w, errs.InternalServer("force"), 500)
} else {
api.JSONStatus(w, tt.response, tt.responseCode)
}
})
got, err := c.RenewWithToken("token")
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Rekey(t *testing.T) { func TestClient_Rekey(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
@ -1060,3 +1131,28 @@ func TestClient_SSHBastion(t *testing.T) {
}) })
} }
} }
func TestClient_GetCaURL(t *testing.T) {
tests := []struct {
name string
caURL string
want string
}{
{"ok", "https://ca.com", "https://ca.com"},
{"ok no schema", "ca.com", "https://ca.com"},
{"ok with port", "https://ca.com:9000", "https://ca.com:9000"},
{"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
if got := c.GetCaURL(); got != tt.want {
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
"sort"
"testing" "testing"
) )
@ -196,7 +197,7 @@ func TestLoadClient(t *testing.T) {
switch { switch {
case gotTransport.TLSClientConfig.GetClientCertificate == nil: case gotTransport.TLSClientConfig.GetClientCertificate == nil:
t.Error("LoadClient() transport does not define GetClientCertificate") t.Error("LoadClient() transport does not define GetClientCertificate")
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()): case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs):
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
default: default:
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
@ -238,3 +239,23 @@ func Test_defaultsConfig_Validate(t *testing.T) {
}) })
} }
} }
// nolint:staticcheck,gocritic
func equalPools(a, b *x509.CertPool) bool {
if reflect.DeepEqual(a, b) {
return true
}
subjects := a.Subjects()
sA := make([]string, len(subjects))
for i := range subjects {
sA[i] = string(subjects[i])
}
subjects = b.Subjects()
sB := make([]string, len(subjects))
for i := range subjects {
sB[i] = string(subjects[i])
}
sort.Strings(sA)
sort.Strings(sB)
return reflect.DeepEqual(sA, sB)
}

View file

@ -346,6 +346,8 @@ func TestIdentity_GetCertPool(t *testing.T) {
return return
} }
if got != nil { if got != nil {
// nolint:staticcheck // we don't have a different way to check
// the certificates in the pool.
subjects := got.Subjects() subjects := got.Subjects()
if !reflect.DeepEqual(subjects, tt.wantSubjects) { if !reflect.DeepEqual(subjects, tt.wantSubjects) {
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)

2
ca/testdata/ca.json vendored
View file

@ -6,7 +6,7 @@
"password": "password", "password": "password",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.3, "maxVersion": 1.3,

View file

@ -6,7 +6,7 @@
"password": "asdf", "password": "asdf",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.2, "maxVersion": 1.2,

Some files were not shown because too many files have changed in this diff Show more