Merge branch 'master' into herman/allow-deny-next

This commit is contained in:
Herman Slatman 2022-03-24 12:36:12 +01:00
commit dc23fd23bf
No known key found for this signature in database
GPG key ID: F4D8A44EA0A75A4F
85 changed files with 2533 additions and 911 deletions

View file

@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.15', '1.16', '1.17' ] go: [ '1.17', '1.18' ]
outputs: outputs:
is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }} is_prerelease: ${{ steps.is_prerelease.outputs.IS_PRERELEASE }}
steps: steps:
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.44.0' version: 'v1.45.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -106,7 +106,7 @@ jobs:
name: Set up Go name: Set up Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: 1.17 go-version: 1.18
- -
name: APT Install name: APT Install
id: aptInstall id: aptInstall
@ -159,7 +159,7 @@ jobs:
name: Setup Go name: Setup Go
uses: actions/setup-go@v2 uses: actions/setup-go@v2
with: with:
go-version: '1.17' go-version: '1.18'
- -
name: Install cosign name: Install cosign
uses: sigstore/cosign-installer@v1.1.0 uses: sigstore/cosign-installer@v1.1.0

View file

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
strategy: strategy:
matrix: matrix:
go: [ '1.16', '1.17' ] go: [ '1.17', '1.18' ]
steps: steps:
- -
name: Checkout name: Checkout
@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.44.0' version: 'v1.45.0'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -58,7 +58,7 @@ jobs:
run: V=1 make ci run: V=1 make ci
- -
name: Codecov name: Codecov
if: matrix.go == '1.17' if: matrix.go == '1.18'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v1.2.1
with: with:
file: ./coverage.out # optional file: ./coverage.out # optional

1
.gitignore vendored
View file

@ -19,3 +19,4 @@ coverage.txt
output output
vendor vendor
.idea .idea
.envrc

View file

@ -19,6 +19,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64
@ -39,6 +40,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64
@ -59,6 +61,7 @@ builds:
- linux_386 - linux_386
- linux_amd64 - linux_amd64
- linux_arm64 - linux_arm64
- linux_arm_5
- linux_arm_6 - linux_arm_6
- linux_arm_7 - linux_arm_7
- windows_amd64 - windows_amd64

View file

@ -6,7 +6,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [Unreleased - 0.18.3] - DATE ## [Unreleased - 0.18.3] - DATE
### Added ### Added
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`.
### Changed ### Changed
- Made SCEP CA URL paths dynamic
- Support two latest versions of Go (1.17, 1.18)
### Deprecated ### Deprecated
### Removed ### Removed
### Fixed ### Fixed
@ -15,6 +18,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
## [0.18.2] - 2022-03-01 ## [0.18.2] - 2022-03-01
### Added ### Added
- Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner. - Added `subscriptionIDs` and `objectIDs` filters to the Azure provisioner.
- [NoSQL](https://github.com/smallstep/nosql/pull/21) package allows filtering
out database drivers using Go tags. For example, using the Go flag
`--tags=nobadger,nobbolt,nomysql` will only compile `step-ca` with the pgx
driver for PostgreSQL.
### Changed ### Changed
- IPv6 addresses are normalized as IP addresses instead of hostnames. - IPv6 addresses are normalized as IP addresses instead of hostnames.
- More descriptive JWK decryption error message. - More descriptive JWK decryption error message.

View file

@ -20,6 +20,7 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -33,6 +34,7 @@ type Authority interface {
// context specifies the Authorize[Sign|Revoke|etc.] method. // context specifies the Authorize[Sign|Revoke|etc.] method.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error) AuthorizeSign(ott string) ([]provisioner.SignOption, error)
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -43,7 +45,7 @@ type Authority interface {
GetProvisioners(cursor string, limit int) (provisioner.List, string, error) GetProvisioners(cursor string, limit int) (provisioner.List, string, error)
Revoke(context.Context, *authority.RevokeOptions) error Revoke(context.Context, *authority.RevokeOptions) error
GetEncryptedKey(kid string) (string, error) GetEncryptedKey(kid string) (string, error)
GetRoots() (federation []*x509.Certificate, err error) GetRoots() ([]*x509.Certificate, error)
GetFederation() ([]*x509.Certificate, error) GetFederation() ([]*x509.Certificate, error)
Version() authority.Version Version() authority.Version
} }

View file

@ -13,6 +13,7 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
@ -27,14 +28,17 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/crypto/jose"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -171,6 +175,7 @@ type mockAuthority struct {
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
authorizeSign func(ott string) ([]provisioner.SignOption, error) authorizeSign func(ott string) ([]provisioner.SignOption, error)
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) sign func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
@ -208,6 +213,13 @@ func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, err
return m.ret1.([]provisioner.SignOption), m.err return m.ret1.([]provisioner.SignOption), m.err
} }
func (m *mockAuthority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
if m.authorizeRenewToken != nil {
return m.authorizeRenewToken(ctx, ott)
}
return m.ret1.(*x509.Certificate), m.err
}
func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions { func (m *mockAuthority) GetTLSOptions() *authority.TLSOptions {
if m.getTLSOptions != nil { if m.getTLSOptions != nil {
return m.getTLSOptions() return m.getTLSOptions()
@ -920,48 +932,141 @@ func Test_caHandler_Renew(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
// Prepare root and leaf for renew after expiry test.
now := time.Now()
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
leafPub, leafPriv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
root := &x509.Certificate{
Subject: pkix.Name{CommonName: "Test Root CA"},
PublicKey: rootPub,
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
IsCA: true,
NotBefore: now.Add(-2 * time.Hour),
NotAfter: now.Add(time.Hour),
}
root, err = x509util.CreateCertificate(root, root, rootPub, rootPriv)
if err != nil {
t.Fatal(err)
}
expiredLeaf := &x509.Certificate{
Subject: pkix.Name{CommonName: "Leaf certificate"},
PublicKey: leafPub,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
EmailAddresses: []string{"test@example.org"},
}
expiredLeaf, err = x509util.CreateCertificate(expiredLeaf, root, leafPub, rootPriv)
if err != nil {
t.Fatal(err)
}
// Generate renew after expiry token
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("x5cInsecure", []string{base64.StdEncoding.EncodeToString(expiredLeaf.Raw)})
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: leafPriv}, so)
if err != nil {
t.Fatal(err)
}
generateX5cToken := func(claims jose.Claims) string {
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return s
}
tests := []struct { tests := []struct {
name string name string
tls *tls.ConnectionState tls *tls.ConnectionState
header http.Header
cert *x509.Certificate cert *x509.Certificate
root *x509.Certificate root *x509.Certificate
err error err error
statusCode int statusCode int
}{ }{
{"ok", cs, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated}, {"ok", cs, nil, parseCertificate(certPEM), parseCertificate(rootPEM), nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, http.StatusBadRequest}, {"ok renew after expiry", &tls.ConnectionState{}, http.Header{
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, http.StatusBadRequest}, "Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
{"renew error", cs, nil, nil, errs.Forbidden("an error"), http.StatusForbidden}, NotBefore: jose.NewNumericDate(now), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
})},
}, expiredLeaf, root, nil, http.StatusCreated},
{"no tls", nil, nil, nil, nil, nil, http.StatusBadRequest},
{"no peer certificates", &tls.ConnectionState{}, nil, nil, nil, nil, http.StatusBadRequest},
{"renew error", cs, nil, nil, nil, errs.Forbidden("an error"), http.StatusForbidden},
{"fail expired token", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
})},
}, expiredLeaf, root, errs.Forbidden("an error"), http.StatusUnauthorized},
{"fail invalid root", &tls.ConnectionState{}, http.Header{
"Authorization": []string{"Bearer " + generateX5cToken(jose.Claims{
NotBefore: jose.NewNumericDate(now.Add(-time.Hour)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
})},
}, expiredLeaf, parseCertificate(rootPEM), errs.Forbidden("an error"), http.StatusUnauthorized},
} }
expected := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ h := New(&mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
if err != nil {
return nil, errs.Unauthorized(err.Error())
}
var claims jose.Claims
if err := jwt.Claims(chain[0][0].PublicKey, &claims); err != nil {
return nil, errs.Unauthorized(err.Error())
}
if err := claims.ValidateWithLeeway(jose.Expected{
Time: now,
}, time.Minute); err != nil {
return nil, errs.Unauthorized(err.Error())
}
return chain[0][0], nil
},
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) }).(*caHandler)
req := httptest.NewRequest("POST", "http://example.com/renew", nil) req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls req.TLS = tt.tls
req.Header = tt.header
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Renew(logging.NewResponseLogger(w), req) h.Renew(logging.NewResponseLogger(w), req)
res := w.Result()
if res.StatusCode != tt.statusCode { res := w.Result()
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode) defer res.Body.Close()
}
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
res.Body.Close()
if err != nil { if err != nil {
t.Errorf("caHandler.Renew unexpected error = %v", err) t.Errorf("caHandler.Renew unexpected error = %v", err)
} }
if res.StatusCode != tt.statusCode {
t.Errorf("caHandler.Renew StatusCode = %d, wants %d", res.StatusCode, tt.statusCode)
t.Errorf("%s", body)
}
if tt.statusCode < http.StatusBadRequest { if tt.statusCode < http.StatusBadRequest {
expected := []byte(`{"crt":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `",` +
`"ca":"` + strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `",` +
`"certChain":["` +
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.cert.Raw})), "\n", `\n`) + `","` +
strings.ReplaceAll(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: tt.root.Raw})), "\n", `\n`) + `"]}`)
if !bytes.Equal(bytes.TrimSpace(body), expected) { if !bytes.Equal(bytes.TrimSpace(body), expected) {
t.Errorf("caHandler.Root Body = %s, wants %s", body, expected) t.Errorf("caHandler.Root Body = \n%s, wants \n%s", body, expected)
} }
} }
}) })

View file

@ -7,7 +7,9 @@ import (
"os" "os"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
@ -59,6 +61,6 @@ func WriteError(w http.ResponseWriter, err error) {
} }
if err := json.NewEncoder(w).Encode(err); err != nil { if err := json.NewEncoder(w).Encode(err); err != nil {
LogError(w, err) log.Error(w, err)
} }
} }

47
api/log/log.go Normal file
View file

@ -0,0 +1,47 @@
// Package log implements API-related logging helpers.
package log
import (
"log"
"net/http"
"github.com/smallstep/certificates/logging"
)
// Error adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func Error(rw http.ResponseWriter, err error) {
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
} else {
log.Println(err)
}
}
// EnabledResponse log the response object if it implements the EnableLogger
// interface.
func EnabledResponse(rw http.ResponseWriter, v interface{}) {
type enableLogger interface {
ToLog() (interface{}, error)
}
if el, ok := v.(enableLogger); ok {
out, err := el.ToLog()
if err != nil {
Error(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}

44
api/log/log_test.go Normal file
View file

@ -0,0 +1,44 @@
package log
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/smallstep/certificates/logging"
)
func TestError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Error(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}

31
api/read/read.go Normal file
View file

@ -0,0 +1,31 @@
// Package read implements request object readers.
package read
import (
"encoding/json"
"io"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/errs"
)
// JSON reads JSON from the request body and stores it in the value
// pointed by v.
func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
// ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
}

46
api/read/read_test.go Normal file
View file

@ -0,0 +1,46 @@
package read
import (
"io"
"reflect"
"strings"
"testing"
"github.com/smallstep/certificates/errs"
)
func TestJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := JSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("JSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("JSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

View file

@ -3,6 +3,7 @@ package api
import ( import (
"net/http" "net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -32,7 +33,7 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
} }
var body RekeyRequest var body RekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -1,20 +1,28 @@
package api package api
import ( import (
"crypto/x509"
"net/http" "net/http"
"strings"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
const (
authorizationHeader = "Authorization"
bearerScheme = "Bearer"
)
// Renew uses the information of certificate in the TLS connection to create a // Renew uses the information of certificate in the TLS connection to create a
// new one. // new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { cert, err := h.getPeerCertificate(r)
WriteError(w, errs.BadRequest("missing client certificate")) if err != nil {
WriteError(w, err)
return return
} }
certChain, err := h.Authority.Renew(r.TLS.PeerCertificates[0]) certChain, err := h.Authority.Renew(cert)
if err != nil { if err != nil {
WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) WriteError(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return return
@ -33,3 +41,15 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: h.Authority.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil
}
if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1])
}
}
return nil, errs.BadRequest("missing client certificate")
}

View file

@ -4,11 +4,13 @@ import (
"context" "context"
"net/http" "net/http"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
) )
// RevokeResponse is the response object that returns the health of the server. // RevokeResponse is the response object that returns the health of the server.
@ -48,7 +50,7 @@ func (r *RevokeRequest) Validate() (err error) {
// TODO: Add CRL and OCSP support. // TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -13,6 +13,7 @@ import (
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
@ -49,7 +50,7 @@ type SignResponse struct {
// information in the certificate request. // information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -9,12 +9,14 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
) )
// SSHAuthority is the interface implemented by a SSH CA authority. // SSHAuthority is the interface implemented by a SSH CA authority.
@ -249,7 +251,7 @@ type SSHBastionResponse struct {
// the request. // the request.
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
@ -393,7 +395,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
// and servers. // and servers.
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
@ -425,7 +427,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }
@ -464,7 +466,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
// SSHBastion provides returns the bastion configured if any. // SSHBastion provides returns the bastion configured if any.
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -4,9 +4,11 @@ import (
"net/http" "net/http"
"time" "time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
) )
// SSHRekeyRequest is the request body of an SSH certificate request. // SSHRekeyRequest is the request body of an SSH certificate request.
@ -38,7 +40,7 @@ type SSHRekeyResponse struct {
// the request. // the request.
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -6,6 +6,8 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -36,7 +38,7 @@ type SSHRenewResponse struct {
// the request. // the request.
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -3,11 +3,13 @@ package api
import ( import (
"net/http" "net/http"
"golang.org/x/crypto/ocsp"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"golang.org/x/crypto/ocsp"
) )
// SSHRevokeResponse is the response object that returns the health of the server. // SSHRevokeResponse is the response object that returns the health of the server.
@ -47,7 +49,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
// NOTE: currently only Passive revocation is supported. // NOTE: currently only Passive revocation is supported.
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
WriteError(w, errs.BadRequestErr(err, "error reading request body")) WriteError(w, errs.BadRequestErr(err, "error reading request body"))
return return
} }

View file

@ -18,12 +18,13 @@ import (
"testing" "testing"
"time" "time"
"golang.org/x/crypto/ssh"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"golang.org/x/crypto/ssh"
) )
var ( var (

View file

@ -4,14 +4,13 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io"
"log"
"net/http" "net/http"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/api/log"
"github.com/smallstep/certificates/logging"
) )
// EnableLogger is an interface that enables response logging for an object. // EnableLogger is an interface that enables response logging for an object.
@ -19,38 +18,6 @@ type EnableLogger interface {
ToLog() (interface{}, error) ToLog() (interface{}, error)
} }
// LogError adds to the response writer the given error if it implements
// logging.ResponseLogger. If it does not implement it, then writes the error
// using the log package.
func LogError(rw http.ResponseWriter, err error) {
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"error": err,
})
} else {
log.Println(err)
}
}
// LogEnabledResponse log the response object if it implements the EnableLogger
// interface.
func LogEnabledResponse(rw http.ResponseWriter, v interface{}) {
if el, ok := v.(EnableLogger); ok {
out, err := el.ToLog()
if err != nil {
LogError(rw, err)
return
}
if rl, ok := rw.(logging.ResponseLogger); ok {
rl.WithFields(map[string]interface{}{
"response": out,
})
} else {
log.Println(out)
}
}
}
// JSON writes the passed value into the http.ResponseWriter. // JSON writes the passed value into the http.ResponseWriter.
func JSON(w http.ResponseWriter, v interface{}) { func JSON(w http.ResponseWriter, v interface{}) {
JSONStatus(w, v, http.StatusOK) JSONStatus(w, v, http.StatusOK)
@ -62,17 +29,19 @@ func JSONStatus(w http.ResponseWriter, v interface{}, status int) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(v); err != nil { if err := json.NewEncoder(w).Encode(v); err != nil {
LogError(w, err) log.Error(w, err)
return return
} }
LogEnabledResponse(w, v)
log.EnabledResponse(w, v)
} }
// JSONNotFound writes a HTTP Not Found response with empty body. // JSONNotFound writes a HTTP Not Found response with empty body.
func JSONNotFound(w http.ResponseWriter) { func JSONNotFound(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
LogEnabledResponse(w, nil) log.EnabledResponse(w, nil)
} }
// ProtoJSON writes the passed value into the http.ResponseWriter. // ProtoJSON writes the passed value into the http.ResponseWriter.
@ -88,33 +57,18 @@ func ProtoJSONStatus(w http.ResponseWriter, m proto.Message, status int) {
b, err := protojson.Marshal(m) b, err := protojson.Marshal(m)
if err != nil { if err != nil {
LogError(w, err) log.Error(w, err)
return return
} }
if _, err := w.Write(b); err != nil { if _, err := w.Write(b); err != nil {
LogError(w, err) log.Error(w, err)
return return
} }
//LogEnabledResponse(w, v)
}
// ReadJSON reads JSON from the request body and stores it in the value // log.EnabledResponse(w, v)
// pointed by v.
func ReadJSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json")
}
return nil
}
// ReadProtoJSON reads JSON from the request body and stores it in the value
// pointed by v.
func ReadProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r)
if err != nil {
return errs.BadRequestErr(err, "error reading request body")
}
return protojson.Unmarshal(data, m)
} }
// ReadProtoJSONWithCheck reads JSON from the request body and stores it in the value // ReadProtoJSONWithCheck reads JSON from the request body and stores it in the value

View file

@ -1,50 +1,13 @@
package api package api
import ( import (
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings"
"testing" "testing"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"github.com/smallstep/certificates/logging" "github.com/smallstep/certificates/logging"
) )
func TestLogError(t *testing.T) {
theError := errors.New("the error")
type args struct {
rw http.ResponseWriter
err error
}
tests := []struct {
name string
args args
withFields bool
}{
{"normalLogger", args{httptest.NewRecorder(), theError}, false},
{"responseLogger", args{logging.NewResponseLogger(httptest.NewRecorder()), theError}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
LogError(tt.args.rw, tt.args.err)
if tt.withFields {
if rl, ok := tt.args.rw.(logging.ResponseLogger); ok {
fields := rl.Fields()
if !reflect.DeepEqual(fields["error"], theError) {
t.Errorf("ResponseLogger[\"error\"] = %s, wants %s", fields["error"], theError)
}
} else {
t.Error("ResponseWriter does not implement logging.ResponseLogger")
}
}
})
}
}
func TestJSON(t *testing.T) { func TestJSON(t *testing.T) {
type args struct { type args struct {
rw http.ResponseWriter rw http.ResponseWriter
@ -88,38 +51,3 @@ func TestJSON(t *testing.T) {
}) })
} }
} }
func TestReadJSON(t *testing.T) {
type args struct {
r io.Reader
v interface{}
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{strings.NewReader(`{"foo":"bar"}`), make(map[string]interface{})}, false},
{"fail", args{strings.NewReader(`{"foo"}`), make(map[string]interface{})}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ReadJSON(tt.args.r, &tt.args.v)
if (err != nil) != tt.wantErr {
t.Errorf("ReadJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
e, ok := err.(*errs.Error)
if ok {
if code := e.StatusCode(); code != 400 {
t.Errorf("error.StatusCode() = %v, wants 400", code)
}
} else {
t.Errorf("error type = %T, wants *Error", err)
}
} else if !reflect.DeepEqual(tt.args.v, map[string]interface{}{"foo": "bar"}) {
t.Errorf("ReadJSON value = %v, wants %v", tt.args.v, map[string]interface{}{"foo": "bar"})
}
})
}
}

View file

@ -5,10 +5,13 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
) )
type adminAuthority interface { type adminAuthority interface {
@ -116,7 +119,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
// CreateAdmin creates a new admin. // CreateAdmin creates a new admin.
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest var body CreateAdminRequest
if err := api.ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }
@ -160,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
// UpdateAdmin updates an existing admin. // UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest var body UpdateAdminRequest
if err := api.ReadJSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) api.WriteError(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
return return
} }

View file

@ -4,12 +4,14 @@ import (
"net/http" "net/http"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/linkedca"
) )
// GetProvisionersResponse is the type for GET /admin/provisioners responses. // GetProvisionersResponse is the type for GET /admin/provisioners responses.
@ -72,7 +74,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
// CreateProvisioner creates a new prov. // CreateProvisioner creates a new prov.
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner) var prov = new(linkedca.Provisioner)
if err := api.ReadProtoJSON(r.Body, prov); err != nil { if err := read.ProtoJSON(r.Body, prov); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }
@ -122,7 +124,7 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
// UpdateProvisioner updates an existing prov. // UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner) var nu = new(linkedca.Provisioner)
if err := api.ReadProtoJSON(r.Body, nu); err != nil { if err := read.ProtoJSON(r.Body, nu); err != nil {
api.WriteError(w, err) api.WriteError(w, err)
return return
} }

View file

@ -71,10 +71,12 @@ type Authority struct {
startTime time.Time startTime time.Time
// Custom functions // Custom functions
sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error) sshBastionFunc func(ctx context.Context, user, hostname string) (*config.Bastion, error)
sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error) sshCheckHostFunc func(ctx context.Context, principal string, tok string, roots []*x509.Certificate) (bool, error)
sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error) sshGetHostsFunc func(ctx context.Context, cert *x509.Certificate) ([]config.Host, error)
getIdentityFunc provisioner.GetIdentityFunc getIdentityFunc provisioner.GetIdentityFunc
authorizeRenewFunc provisioner.AuthorizeRenewFunc
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
// Policy engines // Policy engines
x509Policy policy.X509Policy x509Policy policy.X509Policy

View file

@ -6,6 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"net/http" "net/http"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -276,6 +277,7 @@ func (a *Authority) authorizeRevoke(ctx context.Context, token string) error {
func (a *Authority) authorizeRenew(cert *x509.Certificate) error { func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
serial := cert.SerialNumber.String() serial := cert.SerialNumber.String()
var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)} var opts = []interface{}{errs.WithKeyVal("serialNumber", serial)}
isRevoked, err := a.IsRevoked(serial) isRevoked, err := a.IsRevoked(serial)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
@ -283,7 +285,6 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
if isRevoked { if isRevoked {
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
} }
p, ok := a.provisioners.LoadByCertificate(cert) p, ok := a.provisioners.LoadByCertificate(cert)
if !ok { if !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
@ -371,3 +372,80 @@ func (a *Authority) authorizeSSHRevoke(ctx context.Context, token string) error
} }
return nil return nil
} }
// AuthorizeRenewToken validates the renew token and returns the leaf
// certificate in the x5cInsecure header.
func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) {
var claims jose.Claims
jwt, chain, err := jose.ParseX5cInsecure(ott, a.rootX509Certs)
if err != nil {
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
}
leaf := chain[0][0]
if err := jwt.Claims(leaf.PublicKey, &claims); err != nil {
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
}
p, ok := a.provisioners.LoadByCertificate(leaf)
if !ok {
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
}
if err := a.UseToken(ott, p); err != nil {
return nil, err
}
if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: p.GetName(),
Subject: leaf.Subject.CommonName,
Time: time.Now().UTC(),
}, time.Minute); err != nil {
switch err {
case jose.ErrInvalidIssuer:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid issuer claim (iss)"))
case jose.ErrInvalidSubject:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: invalid subject claim (sub)"))
case jose.ErrNotValidYet:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token not valid yet (nbf)"))
case jose.ErrExpired:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token is expired (exp)"))
case jose.ErrIssuedInTheFuture:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token: token issued in the future (iat)"))
default:
return nil, errs.UnauthorizedErr(err, errs.WithMessage("error validating renew token"))
}
}
audiences := a.config.GetAudiences().Renew
if !matchesAudience(claims.Audience, audiences) {
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
}
return leaf, nil
}
// matchesAudience returns true if A and B share at least one element.
func matchesAudience(as, bs []string) bool {
if len(bs) == 0 || len(as) == 0 {
return false
}
for _, b := range bs {
for _, a := range as {
if b == a || stripPort(a) == stripPort(b) {
return true
}
}
}
return false
}
// stripPort attempts to strip the port from the given url. If parsing the url
// produces errors it will just return the passed argument.
func stripPort(rawurl string) string {
u, err := url.Parse(rawurl)
if err != nil {
return rawurl
}
u.Host = u.Hostname()
return u.String()
}

View file

@ -3,11 +3,15 @@ package authority
import ( import (
"context" "context"
"crypto" "crypto"
"crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"reflect"
"strconv" "strconv"
"testing" "testing"
"time" "time"
@ -20,6 +24,7 @@ import (
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/randutil" "go.step.sm/crypto/randutil"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@ -753,6 +758,7 @@ func TestAuthority_Authorize(t *testing.T) {
func TestAuthority_authorizeRenew(t *testing.T) { func TestAuthority_authorizeRenew(t *testing.T) {
fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt") fooCrt, err := pemutil.ReadCertificate("testdata/certs/foo.crt")
fooCrt.NotAfter = time.Now().Add(time.Hour)
assert.FatalError(t, err) assert.FatalError(t, err)
renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt") renewDisabledCrt, err := pemutil.ReadCertificate("testdata/certs/renew-disabled.crt")
@ -822,7 +828,7 @@ func TestAuthority_authorizeRenew(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
cert: renewDisabledCrt, cert: renewDisabledCrt,
err: errors.New("authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'renew_disabled'"), err: errors.New("authority.authorizeRenew: renew is disabled for provisioner 'renew_disabled'"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -909,6 +915,7 @@ func generateSSHToken(sub, iss, aud string, iat time.Time, sshOpts *provisioner.
} }
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
now := time.Now()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -917,6 +924,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now.Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
}
if err := cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -1003,6 +1016,23 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
} }
func TestAuthority_authorizeSSHRenew(t *testing.T) { func TestAuthority_authorizeSSHRenew(t *testing.T) {
now := time.Now().UTC()
sshpop := func(a *Authority) (*ssh.Certificate, string) {
p, ok := a.provisioners.Load("sshpop/sshpop")
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key")
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
cert, jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
token, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop", []string{"foo.smallstep.com"}, now, jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return cert, token
}
a := testAuthority(t) a := testAuthority(t)
jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass"))) jwk, err := jose.ReadKey("testdata/secrets/step_cli_key_priv.jwk", jose.WithPassword([]byte("pass")))
@ -1012,8 +1042,6 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
(&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID)) (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", jwk.KeyID))
assert.FatalError(t, err) assert.FatalError(t, err)
now := time.Now().UTC()
validIssuer := "step-cli" validIssuer := "step-cli"
type authorizeTest struct { type authorizeTest struct {
@ -1050,27 +1078,34 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
"fail/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
return errs.Forbidden("forbidden")
}))
_, token := sshpop(aa)
return &authorizeTest{
auth: aa,
token: token,
err: errors.New("authority.authorizeSSHRenew: forbidden"),
code: http.StatusForbidden,
}
},
"ok": func(t *testing.T) *authorizeTest { "ok": func(t *testing.T) *authorizeTest {
key, err := pemutil.Read("./testdata/secrets/ssh_host_ca_key") cert, token := sshpop(a)
assert.FatalError(t, err)
signer, ok := key.(crypto.Signer)
assert.Fatal(t, ok, "could not cast ssh signing key to crypto signer")
sshSigner, err := ssh.NewSignerFromSigner(signer)
assert.FatalError(t, err)
cert, _jwk, err := createSSHCert(&ssh.Certificate{CertType: ssh.HostCert}, sshSigner)
assert.FatalError(t, err)
p, ok := a.provisioners.Load("sshpop/sshpop")
assert.Fatal(t, ok, "sshpop provisioner not found in test authority")
tok, err := generateToken("foo", p.GetName(), testAudiences.SSHRenew[0]+"#sshpop/sshpop",
[]string{"foo.smallstep.com"}, now, _jwk, withSSHPOPFile(cert))
assert.FatalError(t, err)
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: tok, token: token,
cert: cert,
}
},
"ok/WithAuthorizeSSHRenewFunc": func(t *testing.T) *authorizeTest {
aa := testAuthority(t, WithAuthorizeSSHRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error {
return nil
}))
cert, token := sshpop(aa)
return &authorizeTest{
auth: aa,
token: token,
cert: cert, cert: cert,
} }
}, },
@ -1290,3 +1325,283 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
}) })
} }
} }
func TestAuthority_AuthorizeRenewToken(t *testing.T) {
ctx := context.Background()
type stepProvisionerASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
_, signer, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
csr, err := x509util.CreateCertificateRequest("test.example.com", []string{"test.example.com"}, signer)
if err != nil {
t.Fatal(err)
}
_, otherSigner, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
generateX5cToken := func(a *Authority, key crypto.Signer, claims jose.Claims, opts ...provisioner.SignOption) (string, *x509.Certificate) {
chain, err := a.Sign(csr, provisioner.SignOptions{}, opts...)
if err != nil {
t.Fatal(err)
}
var x5c []string
for _, c := range chain {
x5c = append(x5c, base64.StdEncoding.EncodeToString(c.Raw))
}
so := new(jose.SignerOptions)
so.WithType("JWT")
so.WithHeader("x5cInsecure", x5c)
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.EdDSA, Key: key}, so)
if err != nil {
t.Fatal(err)
}
s, err := jose.Signed(sig).Claims(claims).CompactSerialize()
if err != nil {
t.Fatal(err)
}
return s, chain[0]
}
now := time.Now()
a1 := testAuthority(t)
t1, c1 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
t2, c2 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now.Add(-time.Hour)
cert.NotAfter = now.Add(-time.Minute)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("foobar"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badIssuer, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "bad-issuer",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "bad-subject",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com",
Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
type args struct {
ctx context.Context
ott string
}
tests := []struct {
name string
authority *Authority
args args
want *x509.Certificate
wantErr bool
}{
{"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true},
{"fail token provisioner", a1, args{ctx, badProvisioner}, nil, true},
{"fail token iss", a1, args{ctx, badIssuer}, nil, true},
{"fail token sub", a1, args{ctx, badSubject}, nil, true},
{"fail token iat", a1, args{ctx, badNotBefore}, nil, true},
{"fail token iat", a1, args{ctx, badExpiry}, nil, true},
{"fail token iat", a1, args{ctx, badIssuedAt}, nil, true},
{"fail token aud", a1, args{ctx, badAudience}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.authority.AuthorizeRenewToken(tt.args.ctx, tt.args.ott)
if (err != nil) != tt.wantErr {
t.Errorf("Authority.AuthorizeRenewToken() error = %+v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Authority.AuthorizeRenewToken() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -27,23 +27,27 @@ var (
DefaultBackdate = time.Minute DefaultBackdate = time.Minute
// DefaultDisableRenewal disables renewals per provisioner. // DefaultDisableRenewal disables renewals per provisioner.
DefaultDisableRenewal = false DefaultDisableRenewal = false
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is
// expired.
DefaultAllowRenewAfterExpiry = false
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally // DefaultEnableSSHCA enable SSH CA features per provisioner or globally
// for all provisioners. // for all provisioners.
DefaultEnableSSHCA = false DefaultEnableSSHCA = false
// GlobalProvisionerClaims default claims for the Authority. Can be overridden // GlobalProvisionerClaims default claims for the Authority. Can be overridden
// by provisioner specific claims. // by provisioner specific claims.
GlobalProvisionerClaims = provisioner.Claims{ GlobalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &DefaultDisableRenewal, MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, EnableSSHCA: &DefaultEnableSSHCA,
EnableSSHCA: &DefaultEnableSSHCA, DisableRenewal: &DefaultDisableRenewal,
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry,
} }
) )
@ -271,28 +275,32 @@ func (c *Config) GetAudiences() provisioner.Audiences {
} }
for _, name := range c.DNSNames { for _, name := range c.DNSNames {
hostname := toHostname(name)
audiences.Sign = append(audiences.Sign, audiences.Sign = append(audiences.Sign,
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/sign", hostname),
fmt.Sprintf("https://%s/sign", toHostname(name)), fmt.Sprintf("https://%s/sign", hostname),
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
fmt.Sprintf("https://%s/ssh/sign", toHostname(name))) fmt.Sprintf("https://%s/ssh/sign", hostname))
audiences.Renew = append(audiences.Renew,
fmt.Sprintf("https://%s/1.0/renew", hostname),
fmt.Sprintf("https://%s/renew", hostname))
audiences.Revoke = append(audiences.Revoke, audiences.Revoke = append(audiences.Revoke,
fmt.Sprintf("https://%s/1.0/revoke", toHostname(name)), fmt.Sprintf("https://%s/1.0/revoke", hostname),
fmt.Sprintf("https://%s/revoke", toHostname(name))) fmt.Sprintf("https://%s/revoke", hostname))
audiences.SSHSign = append(audiences.SSHSign, audiences.SSHSign = append(audiences.SSHSign,
fmt.Sprintf("https://%s/1.0/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/sign", hostname),
fmt.Sprintf("https://%s/ssh/sign", toHostname(name)), fmt.Sprintf("https://%s/ssh/sign", hostname),
fmt.Sprintf("https://%s/1.0/sign", toHostname(name)), fmt.Sprintf("https://%s/1.0/sign", hostname),
fmt.Sprintf("https://%s/sign", toHostname(name))) fmt.Sprintf("https://%s/sign", hostname))
audiences.SSHRevoke = append(audiences.SSHRevoke, audiences.SSHRevoke = append(audiences.SSHRevoke,
fmt.Sprintf("https://%s/1.0/ssh/revoke", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/revoke", hostname),
fmt.Sprintf("https://%s/ssh/revoke", toHostname(name))) fmt.Sprintf("https://%s/ssh/revoke", hostname))
audiences.SSHRenew = append(audiences.SSHRenew, audiences.SSHRenew = append(audiences.SSHRenew,
fmt.Sprintf("https://%s/1.0/ssh/renew", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/renew", hostname),
fmt.Sprintf("https://%s/ssh/renew", toHostname(name))) fmt.Sprintf("https://%s/ssh/renew", hostname))
audiences.SSHRekey = append(audiences.SSHRekey, audiences.SSHRekey = append(audiences.SSHRekey,
fmt.Sprintf("https://%s/1.0/ssh/rekey", toHostname(name)), fmt.Sprintf("https://%s/1.0/ssh/rekey", hostname),
fmt.Sprintf("https://%s/ssh/rekey", toHostname(name))) fmt.Sprintf("https://%s/ssh/rekey", hostname))
} }
return audiences return audiences

View file

@ -92,6 +92,24 @@ func WithGetIdentityFunc(fn func(ctx context.Context, p provisioner.Interface, e
} }
} }
// WithAuthorizeRenewFunc sets a custom function that authorizes the renewal of
// an X.509 certificate.
func WithAuthorizeRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error) Option {
return func(a *Authority) error {
a.authorizeRenewFunc = fn
return nil
}
}
// WithAuthorizeSSHRenewFunc sets a custom function that authorizes the renewal
// of a SSH certificate.
func WithAuthorizeSSHRenewFunc(fn func(ctx context.Context, p *provisioner.Controller, cert *ssh.Certificate) error) Option {
return func(a *Authority) error {
a.authorizeSSHRenewFunc = fn
return nil
}
}
// WithSSHBastionFunc sets a custom function to get the bastion for a // WithSSHBastionFunc sets a custom function to get the bastion for a
// given user-host pair. // given user-host pair.
func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option { func WithSSHBastionFunc(fn func(ctx context.Context, user, host string) (*config.Bastion, error)) Option {

View file

@ -7,8 +7,8 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// ACME is the acme provisioner type, an entity that can authorize the ACME // ACME is the acme provisioner type, an entity that can authorize the ACME
@ -26,7 +26,8 @@ type ACME struct {
RequireEAB bool `json:"requireEAB,omitempty"` RequireEAB bool `json:"requireEAB,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
ctl *Controller
x509Policy policy.X509Policy x509Policy policy.X509Policy
} }
@ -72,7 +73,7 @@ func (p *ACME) GetOptions() *Options {
// DefaultTLSCertDuration returns the default TLS cert duration enforced by // DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner. // the provisioner.
func (p *ACME) DefaultTLSCertDuration() time.Duration { func (p *ACME) DefaultTLSCertDuration() time.Duration {
return p.claimer.DefaultTLSCertDuration() return p.ctl.Claimer.DefaultTLSCertDuration()
} }
// Init initializes and validates the fields of an ACME type. // Init initializes and validates the fields of an ACME type.
@ -84,17 +85,13 @@ func (p *ACME) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
} }
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
} }
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// ACMEIdentifierType encodes ACME Identifier types // ACMEIdentifierType encodes ACME Identifier types
@ -142,10 +139,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeACME, p.Name, ""), newProvisionerExtensionOption(TypeACME, p.Name, ""),
newForceCNOption(p.ForceCN), newForceCNOption(p.ForceCN),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
} }
@ -166,8 +163,5 @@ func (p *ACME) AuthorizeRevoke(ctx context.Context, token string) error {
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *ACME) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName())
}
return nil
} }

View file

@ -91,6 +91,7 @@ func TestACME_Init(t *testing.T) {
} }
func TestACME_AuthorizeRenew(t *testing.T) { func TestACME_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *ACME p *ACME
cert *x509.Certificate cert *x509.Certificate
@ -104,21 +105,27 @@ func TestACME_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("acme.AuthorizeRenew; renew is disabled for acme provisioner '%s'", p.GetName()), err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateACME() p, err := generateACME()
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
} }
}, },
} }
@ -172,18 +179,18 @@ func TestACME_AuthorizeSign(t *testing.T) {
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeACME)) assert.Equals(t, v.Type, TypeACME)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case *forceCNOption: case *forceCNOption:
assert.Equals(t, v.ForceCN, tc.p.ForceCN) assert.Equals(t, v.ForceCN, tc.p.ForceCN)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
default: default:

View file

@ -265,9 +265,9 @@ type AWS struct {
IIDRoots string `json:"iidRoots,omitempty"` IIDRoots string `json:"iidRoots,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *awsConfig config *awsConfig
audiences Audiences audiences Audiences
ctl *Controller
x509Policy policy.X509Policy x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy sshHostPolicy policy.HostPolicy
} }
@ -403,15 +403,11 @@ func (p *AWS) Init(config Config) (err error) {
case p.InstanceAge.Value() < 0: case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative") return errors.New("provisioner instanceAge cannot be negative")
} }
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Add default config // Add default config
if p.config, err = newAWSConfig(p.IIDRoots); err != nil { if p.config, err = newAWSConfig(p.IIDRoots); err != nil {
return err return err
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
// validate IMDS versions // validate IMDS versions
if len(p.IMDSVersions) == 0 { if len(p.IMDSVersions) == 0 {
@ -438,7 +434,9 @@ func (p *AWS) Init(config Config) (err error) {
return err return err
} }
return nil config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
} }
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
@ -486,11 +484,11 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
commonNameValidator(payload.Claims.Subject), commonNameValidator(payload.Claims.Subject),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
), nil ), nil
} }
@ -500,10 +498,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *AWS) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("aws.AuthorizeRenew; renew is disabled for aws provisioner '%s'", p.GetName())
}
return nil
} }
// assertConfig initializes the config if it has not been initialized // assertConfig initializes the config if it has not been initialized
@ -678,7 +673,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(payload.Audience, p.audiences.Sign) { if !matchesAudience(payload.Audience, p.ctl.Audiences.Sign) {
return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)") return nil, errs.Unauthorized("aws.authorizeToken; invalid token - invalid audience claim (aud)")
} }
@ -718,7 +713,7 @@ func (p *AWS) authorizeToken(token string) (*awsPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("aws.AuthorizeSSHSign; ssh ca is disabled for aws provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
@ -766,11 +761,11 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -677,18 +677,18 @@ func TestAWS_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAWS)) assert.Equals(t, v.Type, TypeAWS)
assert.Equals(t, v.Name, tt.aws.GetName()) assert.Equals(t, v.Name, tt.aws.GetName())
assert.Equals(t, v.CredentialID, tt.aws.Accounts[0]) assert.Equals(t, v.CredentialID, tt.aws.Accounts[0])
assert.Len(t, 2, v.KeyValuePairs) assert.Len(t, 2, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.aws.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.aws.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), tt.args.cn) assert.Equals(t, string(v), tt.args.cn)
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.aws.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.aws.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.aws.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.aws.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")})
case emailAddressesValidator: case emailAddressesValidator:
@ -728,7 +728,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com") t1, err := p1.GetIdentityToken("127.0.0.1", "https://ca.smallstep.com")
@ -749,7 +749,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"}, CertType: "host", Principals: []string{"127.0.0.1", "ip-127-0-0-1.us-west-1.compute.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -826,6 +826,7 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
} }
func TestAWS_AuthorizeRenew(t *testing.T) { func TestAWS_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateAWS() p1, err := generateAWS()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateAWS() p2, err := generateAWS()
@ -834,7 +835,7 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -847,8 +848,14 @@ func TestAWS_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -13,11 +13,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens. // azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
@ -31,7 +33,7 @@ const azureDefaultAudience = "https://management.azure.com/"
// azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim. // azureXMSMirIDRegExp is the regular expression used to parse the xms_mirid claim.
// Using case insensitive as resourceGroups appears as resourcegroups. // Using case insensitive as resourceGroups appears as resourcegroups.
var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachines/([^/]+)$`) var azureXMSMirIDRegExp = regexp.MustCompile(`(?i)^/subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.(Compute/virtualMachines|ManagedIdentity/userAssignedIdentities)/([^/]+)$`)
type azureConfig struct { type azureConfig struct {
oidcDiscoveryURL string oidcDiscoveryURL string
@ -97,10 +99,10 @@ type Azure struct {
DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"` DisableTrustOnFirstUse bool `json:"disableTrustOnFirstUse"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *azureConfig config *azureConfig
oidcConfig openIDConfiguration oidcConfig openIDConfiguration
keyStore *keyStore keyStore *keyStore
ctl *Controller
x509Policy policy.X509Policy x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy sshHostPolicy policy.HostPolicy
} }
@ -206,24 +208,20 @@ func (p *Azure) Init(config Config) (err error) {
case p.Audience == "": // use default audience case p.Audience == "": // use default audience
p.Audience = azureDefaultAudience p.Audience = azureDefaultAudience
} }
// Initialize config // Initialize config
p.assertConfig() p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
if err := getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil { if err = getAndDecode(p.config.oidcDiscoveryURL, &p.oidcConfig); err != nil {
return err return
} }
if err := p.oidcConfig.Validate(); err != nil { if err := p.oidcConfig.Validate(); err != nil {
return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL) return errors.Wrapf(err, "error parsing %s", p.config.oidcDiscoveryURL)
} }
// Get JWK key set // Get JWK key set
if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil { if p.keyStore, err = newKeyStore(p.oidcConfig.JWKSetURI); err != nil {
return err return
} }
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
@ -236,7 +234,8 @@ func (p *Azure) Init(config Config) (err error) {
return err return err
} }
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// authorizeToken returns the claims, name, group, subscription, identityObjectID, error. // authorizeToken returns the claims, name, group, subscription, identityObjectID, error.
@ -276,11 +275,14 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str
} }
re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID) re := azureXMSMirIDRegExp.FindStringSubmatch(claims.XMSMirID)
if len(re) != 4 { if len(re) != 5 {
return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID) return nil, "", "", "", "", errs.Unauthorized("azure.authorizeToken; error parsing xms_mirid claim - %s", claims.XMSMirID)
} }
var subscription, group, name string
identityObjectID := claims.ObjectID identityObjectID := claims.ObjectID
subscription, group, name := re[1], re[2], re[3] subscription, group, name = re[1], re[2], re[4]
return &claims, name, group, subscription, identityObjectID, nil return &claims, name, group, subscription, identityObjectID, nil
} }
@ -368,10 +370,10 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
), nil ), nil
} }
@ -381,15 +383,12 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *Azure) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("azure.AuthorizeRenew; renew is disabled for azure provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("azure.AuthorizeSSHSign; sshCA is disabled for provisioner '%s'", p.GetName())
} }
@ -434,11 +433,11 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -95,7 +95,7 @@ func TestAzure_GetIdentityToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, t1, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
@ -237,7 +237,7 @@ func TestAzure_authorizeToken(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), jwk) time.Now(), jwk)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -252,7 +252,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, tok, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -267,7 +267,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
"foo", "subscriptionID", "resourceGroup", "virtualMachine", "foo", "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -321,7 +321,7 @@ func TestAzure_authorizeToken(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
defer srv.Close() defer srv.Close()
tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience, tok, err := generateAzureToken("subject", p.oidcConfig.Issuer, azureDefaultAudience,
p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p.keyStore.keySet.Keys[0]) time.Now(), &p.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
@ -437,28 +437,28 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, t11, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience, failIssuer, err := generateAzureToken("subject", "bad-issuer", azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience", failAudience, err := generateAzureToken("subject", p1.oidcConfig.Issuer, "bad-audience",
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), &p1.keyStore.keySet.Keys[0]) time.Now(), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failExp, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0]) time.Now().Add(-360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failNbf, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0]) time.Now().Add(360*time.Second), &p1.keyStore.keySet.Keys[0])
assert.FatalError(t, err) assert.FatalError(t, err)
failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience, failKey, err := generateAzureToken("subject", p1.oidcConfig.Issuer, azureDefaultAudience,
p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", p1.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm",
time.Now(), badKey) time.Now(), badKey)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -506,18 +506,18 @@ func TestAzure_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeAzure)) assert.Equals(t, v.Type, TypeAzure)
assert.Equals(t, v.Name, tt.azure.GetName()) assert.Equals(t, v.Name, tt.azure.GetName())
assert.Equals(t, v.CredentialID, tt.azure.TenantID) assert.Equals(t, v.CredentialID, tt.azure.TenantID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.azure.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.azure.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), "virtualMachine") assert.Equals(t, string(v), "virtualMachine")
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.azure.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.azure.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.azure.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.azure.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case emailAddressesValidator: case emailAddressesValidator:
@ -538,6 +538,7 @@ func TestAzure_AuthorizeSign(t *testing.T) {
} }
func TestAzure_AuthorizeRenew(t *testing.T) { func TestAzure_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateAzure() p1, err := generateAzure()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateAzure() p2, err := generateAzure()
@ -546,7 +547,7 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -559,8 +560,14 @@ func TestAzure_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -597,7 +604,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := p1.GetIdentityToken("subject", "caURL") t1, err := p1.GetIdentityToken("subject", "caURL")
@ -618,7 +625,7 @@ func TestAzure_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"virtualMachine"}, CertType: "host", Principals: []string{"virtualMachine"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),

View file

@ -10,10 +10,10 @@ import (
// Claims so that individual provisioners can override global claims. // Claims so that individual provisioners can override global claims.
type Claims struct { type Claims struct {
// TLS CA properties // TLS CA properties
MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"` MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"`
// SSH CA properties // SSH CA properties
MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"` MinUserSSHDur *Duration `json:"minUserSSHCertDuration,omitempty"`
MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"` MaxUserSSHDur *Duration `json:"maxUserSSHCertDuration,omitempty"`
@ -22,6 +22,10 @@ type Claims struct {
MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"` MaxHostSSHDur *Duration `json:"maxHostSSHCertDuration,omitempty"`
DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"` DefaultHostSSHDur *Duration `json:"defaultHostSSHCertDuration,omitempty"`
EnableSSHCA *bool `json:"enableSSHCA,omitempty"` EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
// Renewal properties
DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"`
} }
// Claimer is the type that controls claims. It provides an interface around the // Claimer is the type that controls claims. It provides an interface around the
@ -40,19 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
// Claims returns the merge of the inner and global claims. // Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims { func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal() disableRenewal := c.IsDisableRenewal()
allowRenewAfterExpiry := c.AllowRenewAfterExpiry()
enableSSHCA := c.IsSSHCAEnabled() enableSSHCA := c.IsSSHCAEnabled()
return Claims{ return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()}, MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
DisableRenewal: &disableRenewal, MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, EnableSSHCA: &enableSSHCA,
EnableSSHCA: &enableSSHCA, DisableRenewal: &disableRenewal,
AllowRenewAfterExpiry: &allowRenewAfterExpiry,
} }
} }
@ -102,6 +109,16 @@ func (c *Claimer) IsDisableRenewal() bool {
return *c.claims.DisableRenewal return *c.claims.DisableRenewal
} }
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the
// certificate is expired. If the property is not set within the provisioner
// then the global value from the authority configuration will be used.
func (c *Claimer) AllowRenewAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil {
return *c.global.AllowRenewAfterExpiry
}
return *c.claims.AllowRenewAfterExpiry
}
// DefaultSSHCertDuration returns the default SSH certificate duration for the // DefaultSSHCertDuration returns the default SSH certificate duration for the
// given certificate type. // given certificate type.
func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) { func (c *Claimer) DefaultSSHCertDuration(certType uint32) (time.Duration, error) {

View file

@ -152,8 +152,8 @@ func (c *Collection) LoadByToken(token *jose.JSONWebToken, claims *jose.Claims)
// proper id to load the provisioner. // proper id to load the provisioner.
func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) { func (c *Collection) LoadByCertificate(cert *x509.Certificate) (Interface, bool) {
for _, e := range cert.Extensions { for _, e := range cert.Extensions {
if e.Id.Equal(stepOIDProvisioner) { if e.Id.Equal(StepOIDProvisioner) {
var provisioner stepProvisionerASN1 var provisioner extensionASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil { if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false return nil, false
} }

View file

@ -147,6 +147,17 @@ func TestCollection_LoadByToken(t *testing.T) {
} }
func TestCollection_LoadByCertificate(t *testing.T) { func TestCollection_LoadByCertificate(t *testing.T) {
mustExtension := func(typ Type, name, credentialID string) pkix.Extension {
e := Extension{
Type: typ, Name: name, CredentialID: credentialID,
}
ext, err := e.ToExtension()
if err != nil {
t.Fatal(err)
}
return ext
}
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateOIDC() p2, err := generateOIDC()
@ -159,30 +170,21 @@ func TestCollection_LoadByCertificate(t *testing.T) {
byName.Store(p2.GetName(), p2) byName.Store(p2.GetName(), p2)
byName.Store(p3.GetName(), p3) byName.Store(p3.GetName(), p3)
ok1Ext, err := createProvisionerExtension(1, p1.Name, p1.Key.KeyID)
assert.FatalError(t, err)
ok2Ext, err := createProvisionerExtension(2, p2.Name, p2.ClientID)
assert.FatalError(t, err)
ok3Ext, err := createProvisionerExtension(int(TypeACME), p3.Name, "")
assert.FatalError(t, err)
notFoundExt, err := createProvisionerExtension(1, "foo", "bar")
assert.FatalError(t, err)
ok1Cert := &x509.Certificate{ ok1Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok1Ext}, Extensions: []pkix.Extension{mustExtension(1, p1.Name, p1.Key.KeyID)},
} }
ok2Cert := &x509.Certificate{ ok2Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok2Ext}, Extensions: []pkix.Extension{mustExtension(2, p2.Name, p2.ClientID)},
} }
ok3Cert := &x509.Certificate{ ok3Cert := &x509.Certificate{
Extensions: []pkix.Extension{ok3Ext}, Extensions: []pkix.Extension{mustExtension(TypeACME, p3.Name, "")},
} }
notFoundCert := &x509.Certificate{ notFoundCert := &x509.Certificate{
Extensions: []pkix.Extension{notFoundExt}, Extensions: []pkix.Extension{mustExtension(1, "foo", "bar")},
} }
badCert := &x509.Certificate{ badCert := &x509.Certificate{
Extensions: []pkix.Extension{ Extensions: []pkix.Extension{
{Id: stepOIDProvisioner, Critical: false, Value: []byte("foobar")}, {Id: StepOIDProvisioner, Critical: false, Value: []byte("foobar")},
}, },
} }

View file

@ -0,0 +1,194 @@
package provisioner
import (
"context"
"crypto/x509"
"regexp"
"strings"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"golang.org/x/crypto/ssh"
)
// Controller wraps a provisioner with other attributes useful in callback
// functions.
type Controller struct {
Interface
Audiences *Audiences
Claimer *Claimer
IdentityFunc GetIdentityFunc
AuthorizeRenewFunc AuthorizeRenewFunc
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
// NewController initializes a new provisioner controller.
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) {
claimer, err := NewClaimer(claims, config.Claims)
if err != nil {
return nil, err
}
return &Controller{
Interface: p,
Audiences: &config.Audiences,
Claimer: claimer,
IdentityFunc: config.GetIdentityFunc,
AuthorizeRenewFunc: config.AuthorizeRenewFunc,
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
}, nil
}
// GetIdentity returns the identity for a given email.
func (c *Controller) GetIdentity(ctx context.Context, email string) (*Identity, error) {
if c.IdentityFunc != nil {
return c.IdentityFunc(ctx, c.Interface, email)
}
return DefaultIdentityFunc(ctx, c.Interface, email)
}
// AuthorizeRenew returns nil if the given cert can be renewed, returns an error
// otherwise.
func (c *Controller) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if c.AuthorizeRenewFunc != nil {
return c.AuthorizeRenewFunc(ctx, c, cert)
}
return DefaultAuthorizeRenew(ctx, c, cert)
}
// AuthorizeSSHRenew returns nil if the given cert can be renewed, returns an
// error otherwise.
func (c *Controller) AuthorizeSSHRenew(ctx context.Context, cert *ssh.Certificate) error {
if c.AuthorizeSSHRenewFunc != nil {
return c.AuthorizeSSHRenewFunc(ctx, c, cert)
}
return DefaultAuthorizeSSHRenew(ctx, c, cert)
}
// Identity is the type representing an externally supplied identity that is used
// by provisioners to populate certificate fields.
type Identity struct {
Usernames []string `json:"usernames"`
Permissions `json:"permissions"`
}
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// AuthorizeRenewFunc is a function that returns nil if the renewal of a
// certificate is enabled.
type AuthorizeRenewFunc func(ctx context.Context, p *Controller, cert *x509.Certificate) error
// AuthorizeSSHRenewFunc is a function that returns nil if the renewal of the
// given SSH certificate is enabled.
type AuthorizeSSHRenewFunc func(ctx context.Context, p *Controller, cert *ssh.Certificate) error
// DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might
// contain empty strings.
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
// OIDC principals would be:
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
// 2. Sanitized local.
// 3. Raw local (if different).
// 4. Email address.
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
usernames := []string{name}
if i := strings.LastIndex(email, "@"); i >= 0 {
usernames = append(usernames, email[:i])
}
usernames = append(usernames, email)
return &Identity{
Usernames: SanitizeStringSlices(usernames),
}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// DefaultAuthorizeRenew is the default implementation of AuthorizeRenew. It
// will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled.
func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certificate) error {
if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
}
now := time.Now().Truncate(time.Second)
if now.Before(cert.NotBefore) {
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
}
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
return nil
}
// DefaultAuthorizeSSHRenew is the default implementation of AuthorizeSSHRenew. It
// will return an error if the provisioner has the renewal disabled, if the
// certificate is not yet valid or if the certificate is expired and renew after
// expiry is disabled.
func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
if p.Claimer.IsDisableRenewal() {
return errs.Unauthorized("renew is disabled for provisioner '%s'", p.GetName())
}
unixNow := time.Now().Unix()
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return errs.Unauthorized("certificate is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() {
return errs.Unauthorized("certificate has expired")
}
return nil
}
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeStringSlices removes duplicated an empty strings.
func SanitizeStringSlices(original []string) []string {
output := []string{}
seen := make(map[string]struct{})
for _, entry := range original {
if entry == "" {
continue
}
if _, value := seen[entry]; !value {
seen[entry] = struct{}{}
output = append(output, entry)
}
}
return output
}
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}

View file

@ -0,0 +1,391 @@
package provisioner
import (
"context"
"crypto/x509"
"fmt"
"reflect"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
var trueValue = true
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
t.Helper()
c, err := NewClaimer(claims, global)
if err != nil {
t.Fatal(err)
}
return c
}
func mustDuration(t *testing.T, s string) *Duration {
t.Helper()
d, err := NewDuration(s)
if err != nil {
t.Fatal(err)
}
return d
}
func TestNewController(t *testing.T) {
type args struct {
p Interface
claims *Claims
config Config
}
tests := []struct {
name string
args args
want *Controller
wantErr bool
}{
{"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, false},
{"ok with claims", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
}, false},
{"fail claimer", args{&JWK{}, &Claims{
MinTLSDur: mustDuration(t, "24h"),
MaxTLSDur: mustDuration(t, "2h"),
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config)
if (err != nil) != tt.wantErr {
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewController() = %v, want %v", got, tt.want)
}
})
}
}
func TestController_GetIdentity(t *testing.T) {
ctx := context.Background()
type fields struct {
Interface Interface
IdentityFunc GetIdentityFunc
}
type args struct {
ctx context.Context
email string
}
tests := []struct {
name string
fields fields
args args
want *Identity
wantErr bool
}{
{"ok", fields{&OIDC{}, nil}, args{ctx, "jane@doe.org"}, &Identity{
Usernames: []string{"jane", "jane@doe.org"},
}, false},
{"ok custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"jane"}}, nil
}}, args{ctx, "jane@doe.org"}, &Identity{
Usernames: []string{"jane"},
}, false},
{"fail provisioner", fields{&JWK{}, nil}, args{ctx, "jane@doe.org"}, nil, true},
{"fail custom", fields{&OIDC{}, func(ctx context.Context, p Interface, email string) (*Identity, error) {
return nil, fmt.Errorf("an error")
}}, args{ctx, "jane@doe.org"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
IdentityFunc: tt.fields.IdentityFunc,
}
got, err := c.GetIdentity(tt.args.ctx, tt.args.email)
if (err != nil) != tt.wantErr {
t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Controller.GetIdentity() = %v, want %v", got, tt.want)
}
})
}
}
func TestController_AuthorizeRenew(t *testing.T) {
ctx := context.Background()
now := time.Now().Truncate(time.Second)
type fields struct {
Interface Interface
Claimer *Claimer
AuthorizeRenewFunc AuthorizeRenewFunc
}
type args struct {
ctx context.Context
cert *x509.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(time.Hour),
NotAfter: now.Add(2 * time.Hour),
}}, true},
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, true},
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return fmt.Errorf("an error")
}}, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
Claimer: tt.fields.Claimer,
AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
}
if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestController_AuthorizeSSHRenew(t *testing.T) {
ctx := context.Background()
now := time.Now()
type fields struct {
Interface Interface
Claimer *Claimer
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
}
type args struct {
ctx context.Context
cert *ssh.Certificate
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(time.Hour).Unix()),
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
}}, true},
{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, true},
{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return fmt.Errorf("an error")
}}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Controller{
Interface: tt.fields.Interface,
Claimer: tt.fields.Claimer,
AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
}
if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDefaultAuthorizeRenew(t *testing.T) {
ctx := context.Background()
now := time.Now().Truncate(time.Second)
type args struct {
ctx context.Context
p *Controller
cert *x509.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, false},
{"fail disabled", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
{"fail not yet valid", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(time.Hour),
NotAfter: now.Add(2 * time.Hour),
}}, true},
{"fail expired", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{
NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
ctx := context.Background()
now := time.Now()
type args struct {
ctx context.Context
p *Controller
cert *ssh.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, nil, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false},
{"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false},
{"fail disabled", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, true},
{"fail not yet valid", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(time.Hour).Unix()),
ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
}}, true},
{"fail expired", args{ctx, &Controller{
Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -0,0 +1,73 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
)
var (
// StepOIDRoot is the root OID for smallstep.
StepOIDRoot = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64}
// StepOIDProvisioner is the OID for the provisioner extension.
StepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(StepOIDRoot, 1)...)
)
// Extension is the Go representation of the provisioner extension.
type Extension struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
type extensionASN1 struct {
Type int
Name []byte
CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"`
}
// Marshal marshals the extension using encoding/asn1.
func (e *Extension) Marshal() ([]byte, error) {
return asn1.Marshal(extensionASN1{
Type: int(e.Type),
Name: []byte(e.Name),
CredentialID: []byte(e.CredentialID),
KeyValuePairs: e.KeyValuePairs,
})
}
// ToExtension returns the pkix.Extension representation of the provisioner
// extension.
func (e *Extension) ToExtension() (pkix.Extension, error) {
b, err := e.Marshal()
if err != nil {
return pkix.Extension{}, err
}
return pkix.Extension{
Id: StepOIDProvisioner,
Value: b,
}, nil
}
// GetProvisionerExtension goes through all the certificate extensions and
// returns the provisioner extension (1.3.6.1.4.1.37476.9000.64.1).
func GetProvisionerExtension(cert *x509.Certificate) (*Extension, bool) {
for _, e := range cert.Extensions {
if e.Id.Equal(StepOIDProvisioner) {
var provisioner extensionASN1
if _, err := asn1.Unmarshal(e.Value, &provisioner); err != nil {
return nil, false
}
return &Extension{
Type: Type(provisioner.Type),
Name: string(provisioner.Name),
CredentialID: string(provisioner.CredentialID),
KeyValuePairs: provisioner.KeyValuePairs,
}, true
}
}
return nil, false
}

View file

@ -0,0 +1,158 @@
package provisioner
import (
"crypto/x509"
"crypto/x509/pkix"
"reflect"
"testing"
"go.step.sm/crypto/pemutil"
)
func TestExtension_Marshal(t *testing.T) {
type fields struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
tests := []struct {
name string
fields fields
want []byte
wantErr bool
}{
{"ok", fields{TypeJWK, "name", "credentialID", nil}, []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
}, false},
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, []byte{
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
0x13, 0x03, 0x62, 0x61, 0x72,
}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Extension{
Type: tt.fields.Type,
Name: tt.fields.Name,
CredentialID: tt.fields.CredentialID,
KeyValuePairs: tt.fields.KeyValuePairs,
}
got, err := e.Marshal()
if (err != nil) != tt.wantErr {
t.Errorf("Extension.Marshal() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Extension.Marshal() = %x, want %v", got, tt.want)
}
})
}
}
func TestExtension_ToExtension(t *testing.T) {
type fields struct {
Type Type
Name string
CredentialID string
KeyValuePairs []string
}
tests := []struct {
name string
fields fields
want pkix.Extension
wantErr bool
}{
{"ok", fields{TypeJWK, "name", "credentialID", nil}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
},
}, false},
{"ok empty pairs", fields{TypeJWK, "name", "credentialID", []string{}}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x17, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44,
},
}, false},
{"ok with pairs", fields{TypeJWK, "name", "credentialID", []string{"foo", "bar"}}, pkix.Extension{
Id: StepOIDProvisioner,
Value: []byte{
0x30, 0x23, 0x02, 0x01, 0x01, 0x04, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x04, 0x0c, 0x63, 0x72, 0x65,
0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x49,
0x44, 0x30, 0x0a, 0x13, 0x03, 0x66, 0x6f, 0x6f,
0x13, 0x03, 0x62, 0x61, 0x72,
},
}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &Extension{
Type: tt.fields.Type,
Name: tt.fields.Name,
CredentialID: tt.fields.CredentialID,
KeyValuePairs: tt.fields.KeyValuePairs,
}
got, err := e.ToExtension()
if (err != nil) != tt.wantErr {
t.Errorf("Extension.ToExtension() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Extension.ToExtension() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetProvisionerExtension(t *testing.T) {
mustCertificate := func(fn string) *x509.Certificate {
cert, err := pemutil.ReadCertificate(fn)
if err != nil {
t.Fatal(err)
}
return cert
}
type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
args args
want *Extension
want1 bool
}{
{"ok", args{mustCertificate("testdata/certs/good-extension.crt")}, &Extension{
Type: TypeJWK,
Name: "mariano@smallstep.com",
CredentialID: "nvgnR8wSzpUlrt_tC3mvrhwhBx9Y7T1WL_JjcFVWYBQ",
}, true},
{"fail unmarshal", args{mustCertificate("testdata/certs/bad-extension.crt")}, nil, false},
{"missing extension", args{mustCertificate("testdata/certs/aws.crt")}, nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1 := GetProvisionerExtension(tt.args.cert)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetProvisionerExtension() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("GetProvisionerExtension() got1 = %v, want %v", got1, tt.want1)
}
})
}
}

View file

@ -89,10 +89,9 @@ type GCP struct {
InstanceAge Duration `json:"instanceAge,omitempty"` InstanceAge Duration `json:"instanceAge,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
config *gcpConfig config *gcpConfig
keyStore *keyStore keyStore *keyStore
audiences Audiences ctl *Controller
x509Policy policy.X509Policy x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy sshHostPolicy policy.HostPolicy
} }
@ -197,8 +196,7 @@ func (p *GCP) GetIdentityToken(subject, caURL string) (string, error) {
} }
// Init validates and initializes the GCP provisioner. // Init validates and initializes the GCP provisioner.
func (p *GCP) Init(config Config) error { func (p *GCP) Init(config Config) (err error) {
var err error
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -207,16 +205,13 @@ func (p *GCP) Init(config Config) error {
case p.InstanceAge.Value() < 0: case p.InstanceAge.Value() < 0:
return errors.New("provisioner instanceAge cannot be negative") return errors.New("provisioner instanceAge cannot be negative")
} }
// Initialize config // Initialize config
p.assertConfig() p.assertConfig()
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize key store // Initialize key store
p.keyStore, err = newKeyStore(p.config.CertsURL) if p.keyStore, err = newKeyStore(p.config.CertsURL); err != nil {
if err != nil { return
return err
} }
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
@ -229,8 +224,9 @@ func (p *GCP) Init(config Config) error {
return err return err
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// AuthorizeSign validates the given token and returns the sign options that // AuthorizeSign validates the given token and returns the sign options that
@ -282,20 +278,17 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
), nil ), nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *GCP) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("gcp.AuthorizeRenew; renew is disabled for gcp provisioner '%s'", p.GetName())
}
return nil
} }
// assertConfig initializes the config if it has not been initialized. // assertConfig initializes the config if it has not been initialized.
@ -342,7 +335,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
} }
// validate audiences with the defaults // validate audiences with the defaults
if !matchesAudience(claims.Audience, p.audiences.Sign) { if !matchesAudience(claims.Audience, p.ctl.Audiences.Sign) {
return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)") return nil, errs.Unauthorized("gcp.authorizeToken; invalid gcp token - invalid audience claim (aud)")
} }
@ -397,7 +390,7 @@ func (p *GCP) authorizeToken(token string) (*gcpPayload, error) {
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("gcp.AuthorizeSSHSign; sshCA is disabled for gcp provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token) claims, err := p.authorizeToken(token)
@ -445,11 +438,11 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -549,18 +549,18 @@ func TestGCP_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeGCP)) assert.Equals(t, v.Type, TypeGCP)
assert.Equals(t, v.Name, tt.gcp.GetName()) assert.Equals(t, v.Name, tt.gcp.GetName())
assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0]) assert.Equals(t, v.CredentialID, tt.gcp.ServiceAccounts[0])
assert.Len(t, 4, v.KeyValuePairs) assert.Len(t, 4, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.gcp.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.gcp.ctl.Claimer.DefaultTLSCertDuration())
case commonNameSliceValidator: case commonNameSliceValidator:
assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) assert.Equals(t, []string(v), []string{"instance-name", "instance-id", "instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.gcp.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.gcp.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.gcp.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.gcp.ctl.Claimer.MaxTLSCertDuration())
case ipAddressesValidator: case ipAddressesValidator:
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case emailAddressesValidator: case emailAddressesValidator:
@ -597,7 +597,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p3.Claims = &Claims{EnableSSHCA: &disable} p3.Claims = &Claims{EnableSSHCA: &disable}
p3.claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims) p3.ctl.Claimer, err = NewClaimer(p3.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
t1, err := generateGCPToken(p1.ServiceAccounts[0], t1, err := generateGCPToken(p1.ServiceAccounts[0],
@ -624,7 +624,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedHostOptions := &SignSSHOptions{ expectedHostOptions := &SignSSHOptions{
CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}, CertType: "host", Principals: []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(hostDuration)),
@ -700,6 +700,7 @@ func TestGCP_AuthorizeSSHSign(t *testing.T) {
} }
func TestGCP_AuthorizeRenew(t *testing.T) { func TestGCP_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateGCP() p1, err := generateGCP()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateGCP() p2, err := generateGCP()
@ -708,7 +709,7 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -721,8 +722,14 @@ func TestGCP_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renewal-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renewal-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -7,11 +7,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// jwtPayload extends jwt.Claims with step attributes. // jwtPayload extends jwt.Claims with step attributes.
@ -36,8 +38,7 @@ type JWK struct {
EncryptedKey string `json:"encryptedKey,omitempty"` EncryptedKey string `json:"encryptedKey,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences
x509Policy policy.X509Policy x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy sshHostPolicy policy.HostPolicy
sshUserPolicy policy.UserPolicy sshUserPolicy policy.UserPolicy
@ -102,11 +103,6 @@ func (p *JWK) Init(config Config) (err error) {
return errors.New("provisioner key cannot be empty") return errors.New("provisioner key cannot be empty")
} }
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
@ -122,8 +118,8 @@ func (p *JWK) Init(config Config) (err error) {
return err return err
} }
p.audiences = config.Audiences p.ctl, err = NewController(p, p.Claims, config)
return err return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -165,14 +161,14 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke) // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign")
} }
@ -199,12 +195,12 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs), defaultSANsValidator(claims.SANs),
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
}, nil }, nil
} }
@ -214,12 +210,13 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { // if p.claimer.IsDisableRenewal() {
return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName()) // return errs.Unauthorized("jwk.AuthorizeRenew; renew is disabled for jwk provisioner '%s'", p.GetName())
} // }
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew) // TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew)
//return p.authorizeRenew(cert) //return p.authorizeRenew(cert)
return nil // return nil
return p.ctl.AuthorizeRenew(ctx, cert)
} }
// func (p *JWK) authorizeRenew(cert *x509.Certificate) error { // func (p *JWK) authorizeRenew(cert *x509.Certificate) error {
@ -232,10 +229,10 @@ func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("jwk.AuthorizeSSHSign; sshCA is disabled for jwk provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHSign")
} }
@ -292,11 +289,11 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
return append(signOptions, return append(signOptions,
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed
@ -306,7 +303,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.SSHRevoke) _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
// TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke) // TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
} }

View file

@ -76,13 +76,13 @@ func TestJWK_Init(t *testing.T) {
}, },
"fail-bad-claims": func(t *testing.T) ProvisionerValidateTest { "fail-bad-claims": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences, Claims: &Claims{DefaultTLSDur: &Duration{0}}}, p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, Claims: &Claims{DefaultTLSDur: &Duration{0}}},
err: errors.New("claims: MinTLSCertDuration must be greater than 0"), err: errors.New("claims: MinTLSCertDuration must be greater than 0"),
} }
}, },
"ok": func(t *testing.T) ProvisionerValidateTest { "ok": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}, audiences: testAudiences}, p: &JWK{Name: "foo", Type: "bar", Key: &jose.JSONWebKey{}},
} }
}, },
} }
@ -300,18 +300,18 @@ func TestJWK_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeJWK)) assert.Equals(t, v.Type, TypeJWK)
assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID) assert.Equals(t, v.CredentialID, tt.prov.Key.KeyID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
case commonNameValidator: case commonNameValidator:
assert.Equals(t, string(v), "subject") assert.Equals(t, string(v), "subject")
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case defaultSANsValidator: case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans) assert.Equals(t, []string(v), tt.sans)
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
@ -327,6 +327,7 @@ func TestJWK_AuthorizeSign(t *testing.T) {
} }
func TestJWK_AuthorizeRenew(t *testing.T) { func TestJWK_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateJWK() p1, err := generateJWK()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateJWK() p2, err := generateJWK()
@ -335,7 +336,7 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -348,8 +349,14 @@ func TestJWK_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -375,7 +382,7 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p2.Claims = &Claims{EnableSSHCA: &disable} p2.Claims = &Claims{EnableSSHCA: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
jwk, err := decryptJSONWebKey(p1.EncryptedKey) jwk, err := decryptJSONWebKey(p1.EncryptedKey)
@ -404,8 +411,8 @@ func TestJWK_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name"}, CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),
@ -487,8 +494,8 @@ func TestJWK_AuthorizeSign_SSHOptions(t *testing.T) {
signer, err := generateJSONWebKey() signer, err := generateJSONWebKey()
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name"}, CertType: "user", Principals: []string{"name"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),

View file

@ -10,12 +10,14 @@ import (
"net/http" "net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// NOTE: There can be at most one kubernetes service account provisioner configured // NOTE: There can be at most one kubernetes service account provisioner configured
@ -43,16 +45,15 @@ type k8sSAPayload struct {
// entity trusted to make signature requests. // entity trusted to make signature requests.
type K8sSA struct { type K8sSA struct {
*base *base
ID string `json:"-"` ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
PubKeys []byte `json:"publicKeys,omitempty"` PubKeys []byte `json:"publicKeys,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
audiences Audiences
//kauthn kauthn.AuthenticationV1Interface //kauthn kauthn.AuthenticationV1Interface
pubKeys []interface{} pubKeys []interface{}
ctl *Controller
x509Policy policy.X509Policy x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy sshHostPolicy policy.HostPolicy
sshUserPolicy policy.UserPolicy sshUserPolicy policy.UserPolicy
@ -143,11 +144,6 @@ func (p *K8sSA) Init(config Config) (err error) {
p.kauthn = k8s.AuthenticationV1() p.kauthn = k8s.AuthenticationV1()
*/ */
// Update claims with global ones
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
@ -163,8 +159,8 @@ func (p *K8sSA) Init(config Config) (err error) {
return err return err
} }
p.audiences = config.Audiences p.ctl, err = NewController(p, p.Claims, config)
return err return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -231,13 +227,13 @@ func (p *K8sSA) authorizeToken(token string, audiences []string) (*k8sSAPayload,
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error { func (p *K8sSA) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSign")
} }
@ -260,28 +256,25 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
profileDefaultDuration(p.claimer.DefaultTLSCertDuration()), profileDefaultDuration(p.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
}, nil }, nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *K8sSA) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign validates an request for an SSH certificate. // AuthorizeSSHSign validates an request for an SSH certificate.
func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "k8ssa.AuthorizeSSHSign")
} }
@ -303,11 +296,11 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
// Require type, key-id and principals in the SignSSHOptions. // Require type, key-id and principals in the SignSSHOptions.
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -179,6 +179,7 @@ func TestK8sSA_AuthorizeRevoke(t *testing.T) {
} }
func TestK8sSA_AuthorizeRenew(t *testing.T) { func TestK8sSA_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *K8sSA p *K8sSA
cert *x509.Certificate cert *x509.Certificate
@ -192,21 +193,27 @@ func TestK8sSA_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("k8ssa.AuthorizeRenew; renew is disabled for k8sSA provisioner '%s'", p.GetName()), err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
p, err := generateK8sSA(nil) p, err := generateK8sSA(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
cert: &x509.Certificate{}, cert: &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
},
} }
}, },
} }
@ -275,16 +282,16 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeK8sSA)) assert.Equals(t, v.Type, TypeK8sSA)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
default: default:
@ -313,7 +320,7 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p.Claims = &Claims{EnableSSHCA: &disable} p.Claims = &Claims{EnableSSHCA: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
@ -365,11 +372,11 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case *sshCertOptionsRequireValidator: case *sshCertOptionsRequireValidator:
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator: case *sshCertDefaultValidator:
case *sshDefaultDuration: case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshNamePolicyValidator: case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine) assert.Equals(t, nil, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine) assert.Equals(t, nil, v.hostPolicyEngine)

View file

@ -41,16 +41,15 @@ type Nebula struct {
Roots []byte `json:"roots"` Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer
caPool *nebula.NebulaCAPool caPool *nebula.NebulaCAPool
audiences Audiences ctl *Controller
x509Policy policy.X509Policy x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy sshHostPolicy policy.HostPolicy
sshUserPolicy policy.UserPolicy sshUserPolicy policy.UserPolicy
} }
// Init verifies and initializes the Nebula provisioner. // Init verifies and initializes the Nebula provisioner.
func (p *Nebula) Init(config Config) error { func (p *Nebula) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -60,18 +59,11 @@ func (p *Nebula) Init(config Config) error {
return errors.New("provisioner root(s) cannot be empty") return errors.New("provisioner root(s) cannot be empty")
} }
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots) p.caPool, err = nebula.NewCAPoolFromBytes(p.Roots)
if err != nil { if err != nil {
return errs.InternalServer("failed to create ca pool: %v", err) return errs.InternalServer("failed to create ca pool: %v", err)
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
@ -87,7 +79,9 @@ func (p *Nebula) Init(config Config) error {
return err return err
} }
return nil config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config)
return
} }
// GetID returns the provisioner id. // GetID returns the provisioner id.
@ -139,7 +133,7 @@ func (p *Nebula) GetEncryptedKey() (kid, key string, ok bool) {
// AuthorizeSign returns the list of SignOption for a Sign request. // AuthorizeSign returns the list of SignOption for a Sign request.
func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
crt, claims, err := p.authorizeToken(token, p.audiences.Sign) crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -173,7 +167,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeNebula, p.Name, ""), newProvisionerExtensionOption(TypeNebula, p.Name, ""),
profileLimitDuration{ profileLimitDuration{
def: p.claimer.DefaultTLSCertDuration(), def: p.ctl.Claimer.DefaultTLSCertDuration(),
notBefore: crt.Details.NotBefore, notBefore: crt.Details.NotBefore,
notAfter: crt.Details.NotAfter, notAfter: crt.Details.NotAfter,
}, },
@ -184,7 +178,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
IPs: crt.Details.Ips, IPs: crt.Details.Ips,
}, },
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
}, nil }, nil
} }
@ -192,11 +186,11 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
// Currently the Nebula provisioner only grants host SSH certificates. // Currently the Nebula provisioner only grants host SSH certificates.
func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) return nil, errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
} }
crt, claims, err := p.authorizeToken(token, p.audiences.SSHSign) crt, claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -274,11 +268,11 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
return append(signOptions, return append(signOptions,
templateOptions, templateOptions,
// Checks the validity bounds, and set the validity if has not been set. // Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.claimer, crt.Details.NotAfter}, &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed
@ -288,23 +282,20 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error { func (p *Nebula) AuthorizeRenew(ctx context.Context, crt *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, crt)
return errs.Unauthorized("renew is disabled for nebula provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeRevoke returns an error if the token is not valid. // AuthorizeRevoke returns an error if the token is not valid.
func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error { func (p *Nebula) AuthorizeRevoke(ctx context.Context, token string) error {
return p.validateToken(token, p.audiences.Revoke) return p.validateToken(token, p.ctl.Audiences.Revoke)
} }
// AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid. // AuthorizeSSHRevoke returns an error if SSH is disabled or the token is invalid.
func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *Nebula) AuthorizeSSHRevoke(ctx context.Context, token string) error {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name) return errs.Unauthorized("ssh is disabled for nebula provisioner '%s'", p.Name)
} }
if _, _, err := p.authorizeToken(token, p.audiences.SSHRevoke); err != nil { if _, _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke); err != nil {
return err return err
} }
return nil return nil

View file

@ -327,7 +327,7 @@ func TestNebula_GetIDForToken(t *testing.T) {
func TestNebula_GetTokenID(t *testing.T) { func TestNebula_GetTokenID(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer) c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
t1 := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv) t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv)
_, claims, err := parseToken(t1) _, claims, err := parseToken(t1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -441,8 +441,8 @@ func TestNebula_AuthorizeSign(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], now(), nil, crt, priv) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), nil, crt, priv)
pBadOptions, _, _ := mustNebulaProvisioner(t) pBadOptions, _, _ := mustNebulaProvisioner(t)
pBadOptions.caPool = p.caPool pBadOptions.caPool = p.caPool
@ -483,20 +483,20 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan", "10.1.0.1"}, Principals: []string{"test.lan", "10.1.0.1"},
}, crt, priv) }, crt, priv)
okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), nil, crt, priv) okNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), nil, crt, priv)
okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ okWithValidity := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)), ValidAfter: NewTimeDuration(now().Add(1 * time.Hour)),
ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)), ValidBefore: NewTimeDuration(now().Add(10 * time.Hour)),
}, crt, priv) }, crt, priv)
failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ failUserCert := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "user", CertType: "user",
}, crt, priv) }, crt, priv)
failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], now(), &SignSSHOptions{ failPrincipals := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], now(), &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan", "10.1.0.1", "foo.bar"}, Principals: []string{"test.lan", "10.1.0.1", "foo.bar"},
@ -549,6 +549,8 @@ func TestNebula_AuthorizeSSHSign(t *testing.T) {
func TestNebula_AuthorizeRenew(t *testing.T) { func TestNebula_AuthorizeRenew(t *testing.T) {
ctx := context.TODO() ctx := context.TODO()
now := time.Now().Truncate(time.Second)
// Ok provisioner // Ok provisioner
p, _, _ := mustNebulaProvisioner(t) p, _, _ := mustNebulaProvisioner(t)
@ -567,8 +569,14 @@ func TestNebula_AuthorizeRenew(t *testing.T) {
args args args args
wantErr bool wantErr bool
}{ }{
{"ok", p, args{ctx, &x509.Certificate{}}, false}, {"ok", p, args{ctx, &x509.Certificate{
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{}}, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, false},
{"fail disabled", pDisabled, args{ctx, &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -584,12 +592,12 @@ func TestNebula_AuthorizeRevoke(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
// Fail different CA // Fail different CA
nc, signer := mustNebulaCA(t) nc, signer := mustNebulaCA(t)
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
failToken := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Revoke[0], now(), nil, crt, priv) failToken := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Revoke[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -618,12 +626,12 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
// Ok provisioner // Ok provisioner
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) ok := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
// Fail different CA // Fail different CA
nc, signer := mustNebulaCA(t) nc, signer := mustNebulaCA(t)
crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer) crt, priv = mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, nc, signer)
failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRevoke[0], now(), nil, crt, priv) failToken := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRevoke[0], now(), nil, crt, priv)
// Provisioner with SSH disabled // Provisioner with SSH disabled
var bFalse bool var bFalse bool
@ -657,7 +665,7 @@ func TestNebula_AuthorizeSSHRevoke(t *testing.T) {
func TestNebula_AuthorizeSSHRenew(t *testing.T) { func TestNebula_AuthorizeSSHRenew(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRenew[0], now(), nil, crt, priv) t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRenew[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -689,7 +697,7 @@ func TestNebula_AuthorizeSSHRenew(t *testing.T) {
func TestNebula_AuthorizeSSHRekey(t *testing.T) { func TestNebula_AuthorizeSSHRekey(t *testing.T) {
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHRekey[0], now(), nil, crt, priv) t1 := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHRekey[0], now(), nil, crt, priv)
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -726,20 +734,20 @@ func TestNebula_authorizeToken(t *testing.T) {
t1 := now() t1 := now()
p, ca, signer := mustNebulaProvisioner(t) p, ca, signer := mustNebulaProvisioner(t)
crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer)
ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) ok := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv) okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1, nil, crt, priv)
okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{ okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, &SignSSHOptions{
CertType: "host", CertType: "host",
KeyID: "test.lan", KeyID: "test.lan",
Principals: []string{"test.lan"}, Principals: []string{"test.lan"},
}, crt, priv) }, crt, priv)
okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv) okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.ctl.Audiences.SSHSign[0], t1, nil, crt, priv)
// Token with errors // Token with errors
failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv)
failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failIssuer := mustNebulaToken(t, "test.lan", "foo", p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv) failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv)
failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) failSubject := mustNebulaToken(t, "", p.Name, p.ctl.Audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv)
// Not a nebula token // Not a nebula token
jwk, err := generateJSONWebKey() jwk, err := generateJSONWebKey()
@ -761,7 +769,7 @@ func TestNebula_authorizeToken(t *testing.T) {
IssuedAt: jose.NewNumericDate(t1), IssuedAt: jose.NewNumericDate(t1),
NotBefore: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1),
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
Audience: []string{p.audiences.Sign[0]}, Audience: []string{p.ctl.Audiences.Sign[0]},
} }
sshClaims := jose.Claims{ sshClaims := jose.Claims{
ID: "[REPLACEME]", ID: "[REPLACEME]",
@ -770,7 +778,7 @@ func TestNebula_authorizeToken(t *testing.T) {
IssuedAt: jose.NewNumericDate(t1), IssuedAt: jose.NewNumericDate(t1),
NotBefore: jose.NewNumericDate(t1), NotBefore: jose.NewNumericDate(t1),
Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)),
Audience: []string{p.audiences.SSHSign[0]}, Audience: []string{p.ctl.Audiences.SSHSign[0]},
} }
type args struct { type args struct {
@ -785,14 +793,14 @@ func TestNebula_authorizeToken(t *testing.T) {
want1 *jwtPayload want1 *jwtPayload
wantErr bool wantErr bool
}{ }{
{"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{ {"ok x509", p, args{ok, p.ctl.Audiences.Sign}, crt, &jwtPayload{
Claims: x509Claims, Claims: x509Claims,
SANs: []string{"10.1.0.1"}, SANs: []string{"10.1.0.1"},
}, false}, }, false},
{"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{ {"ok x509 no sans", p, args{okNoSANs, p.ctl.Audiences.Sign}, crt, &jwtPayload{
Claims: x509Claims, Claims: x509Claims,
}, false}, }, false},
{"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{ {"ok ssh", p, args{okSSH, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
Claims: sshClaims, Claims: sshClaims,
Step: &stepPayload{ Step: &stepPayload{
SSH: &SignSSHOptions{ SSH: &SignSSHOptions{
@ -802,16 +810,16 @@ func TestNebula_authorizeToken(t *testing.T) {
}, },
}, },
}, false}, }, false},
{"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{ {"ok ssh no principals", p, args{okSSHNoOptions, p.ctl.Audiences.SSHSign}, crt, &jwtPayload{
Claims: sshClaims, Claims: sshClaims,
}, false}, }, false},
{"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true}, {"fail parse", p, args{"bad.token", p.ctl.Audiences.Sign}, nil, nil, true},
{"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true}, {"fail header", p, args{simpleToken, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true}, {"fail verify", p2, args{ok, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true}, {"fail claims nbf", p, args{failNotBefore, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true}, {"fail claims iss", p, args{failIssuer, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true}, {"fail claims aud", p, args{failAudience, p.ctl.Audiences.Sign}, nil, nil, true},
{"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true}, {"fail claims sub", p, args{failSubject, p.ctl.Audiences.Sign}, nil, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -12,11 +12,13 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// openIDConfiguration contains the necessary properties in the // openIDConfiguration contains the necessary properties in the
@ -93,8 +95,8 @@ type OIDC struct {
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
configuration openIDConfiguration configuration openIDConfiguration
keyStore *keyStore keyStore *keyStore
claimer *Claimer
getIdentityFunc GetIdentityFunc getIdentityFunc GetIdentityFunc
ctl *Controller
x509Policy policy.X509Policy x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy sshHostPolicy policy.HostPolicy
sshUserPolicy policy.UserPolicy sshUserPolicy policy.UserPolicy
@ -176,11 +178,6 @@ func (o *OIDC) Init(config Config) (err error) {
} }
} }
// Update claims with global ones
if o.claimer, err = NewClaimer(o.Claims, config.Claims); err != nil {
return err
}
// Decode and validate openid-configuration endpoint // Decode and validate openid-configuration endpoint
u, err := url.Parse(o.ConfigurationEndpoint) u, err := url.Parse(o.ConfigurationEndpoint)
if err != nil { if err != nil {
@ -227,7 +224,8 @@ func (o *OIDC) Init(config Config) (err error) {
return err return err
} }
return nil o.ctl, err = NewController(o, o.Claims, config)
return
} }
// ValidatePayload validates the given token payload. // ValidatePayload validates the given token payload.
@ -379,10 +377,10 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID), newProvisionerExtensionOption(TypeOIDC, o.Name, o.ClientID),
profileDefaultDuration(o.claimer.DefaultTLSCertDuration()), profileDefaultDuration(o.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(o.claimer.MinTLSCertDuration(), o.claimer.MaxTLSCertDuration()), newValidityValidator(o.ctl.Claimer.MinTLSCertDuration(), o.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(o.x509Policy), newX509NamePolicyValidator(o.x509Policy),
}, nil }, nil
} }
@ -392,15 +390,12 @@ func (o *OIDC) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (o *OIDC) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if o.claimer.IsDisableRenewal() { return o.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("oidc.AuthorizeRenew; renew is disabled for oidc provisioner '%s'", o.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !o.claimer.IsSSHCAEnabled() { if !o.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName()) return nil, errs.Unauthorized("oidc.AuthorizeSSHSign; sshCA is disabled for oidc provisioner '%s'", o.GetName())
} }
claims, err := o.authorizeToken(token) claims, err := o.authorizeToken(token)
@ -415,7 +410,7 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
// Get the identity using either the default identityFunc or one injected // Get the identity using either the default identityFunc or one injected
// externally. Note that the PreferredUsername might be empty. // externally. Note that the PreferredUsername might be empty.
// TBD: Would preferred_username present a safety issue here? // TBD: Would preferred_username present a safety issue here?
iden, err := o.getIdentityFunc(ctx, o, claims.Email) iden, err := o.ctl.GetIdentity(ctx, claims.Email)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "oidc.AuthorizeSSHSign")
} }
@ -466,11 +461,11 @@ func (o *OIDC) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption
return append(signOptions, return append(signOptions,
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{o.claimer}, &sshDefaultDuration{o.ctl.Claimer},
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{o.claimer}, &sshCertValidityValidator{o.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -327,16 +327,16 @@ func TestOIDC_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeOIDC)) assert.Equals(t, v.Type, TypeOIDC)
assert.Equals(t, v.Name, tt.prov.GetName()) assert.Equals(t, v.Name, tt.prov.GetName())
assert.Equals(t, v.CredentialID, tt.prov.ClientID) assert.Equals(t, v.CredentialID, tt.prov.ClientID)
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileDefaultDuration: case profileDefaultDuration:
assert.Equals(t, time.Duration(v), tt.prov.claimer.DefaultTLSCertDuration()) assert.Equals(t, time.Duration(v), tt.prov.ctl.Claimer.DefaultTLSCertDuration())
case defaultPublicKeyValidator: case defaultPublicKeyValidator:
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tt.prov.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tt.prov.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case emailOnlyIdentity: case emailOnlyIdentity:
assert.Equals(t, string(v), "name@smallstep.com") assert.Equals(t, string(v), "name@smallstep.com")
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
@ -413,6 +413,7 @@ func TestOIDC_AuthorizeRevoke(t *testing.T) {
} }
func TestOIDC_AuthorizeRenew(t *testing.T) { func TestOIDC_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
p1, err := generateOIDC() p1, err := generateOIDC()
assert.FatalError(t, err) assert.FatalError(t, err)
p2, err := generateOIDC() p2, err := generateOIDC()
@ -421,7 +422,7 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p2.Claims = &Claims{DisableRenewal: &disable} p2.Claims = &Claims{DisableRenewal: &disable}
p2.claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims) p2.ctl.Claimer, err = NewClaimer(p2.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
type args struct { type args struct {
@ -434,8 +435,14 @@ func TestOIDC_AuthorizeRenew(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{nil}, http.StatusOK, false}, {"ok", p1, args{&x509.Certificate{
{"fail/renew-disabled", p2, args{nil}, http.StatusUnauthorized, true}, NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusOK, false},
{"fail/renew-disabled", p2, args{&x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}}, http.StatusUnauthorized, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -480,7 +487,7 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
disable := false disable := false
p6.Claims = &Claims{EnableSSHCA: &disable} p6.Claims = &Claims{EnableSSHCA: &disable}
p6.claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims) p6.ctl.Claimer, err = NewClaimer(p6.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
// Update configuration endpoints and initialize // Update configuration endpoints and initialize
@ -496,10 +503,10 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
assert.FatalError(t, p4.Init(config)) assert.FatalError(t, p4.Init(config))
assert.FatalError(t, p5.Init(config)) assert.FatalError(t, p5.Init(config))
p4.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { p4.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
return &Identity{Usernames: []string{"max", "mariano"}}, nil return &Identity{Usernames: []string{"max", "mariano"}}, nil
} }
p5.getIdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) { p5.ctl.IdentityFunc = func(ctx context.Context, p Interface, email string) (*Identity, error) {
return nil, errors.New("force") return nil, errors.New("force")
} }
// Additional test needed for empty usernames and duplicate email and usernames // Additional test needed for empty usernames and duplicate email and usernames
@ -529,8 +536,8 @@ func TestOIDC_AuthorizeSSHSign(t *testing.T) {
rsa1024, err := rsa.GenerateKey(rand.Reader, 1024) rsa1024, err := rsa.GenerateKey(rand.Reader, 1024)
assert.FatalError(t, err) assert.FatalError(t, err)
userDuration := p1.claimer.DefaultUserSSHCertDuration() userDuration := p1.ctl.Claimer.DefaultUserSSHCertDuration()
hostDuration := p1.claimer.DefaultHostSSHCertDuration() hostDuration := p1.ctl.Claimer.DefaultHostSSHCertDuration()
expectedUserOptions := &SignSSHOptions{ expectedUserOptions := &SignSSHOptions{
CertType: "user", Principals: []string{"name", "name@smallstep.com"}, CertType: "user", Principals: []string{"name", "name@smallstep.com"},
ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)), ValidAfter: NewTimeDuration(tm), ValidBefore: NewTimeDuration(tm.Add(userDuration)),

View file

@ -6,7 +6,6 @@ import (
"encoding/json" "encoding/json"
stderrors "errors" stderrors "errors"
"net/url" "net/url"
"regexp"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -47,6 +46,7 @@ var ErrAllowTokenReuse = stderrors.New("allow token reuse")
// Audiences stores all supported audiences by request type. // Audiences stores all supported audiences by request type.
type Audiences struct { type Audiences struct {
Sign []string Sign []string
Renew []string
Revoke []string Revoke []string
SSHSign []string SSHSign []string
SSHRevoke []string SSHRevoke []string
@ -57,6 +57,7 @@ type Audiences struct {
// All returns all supported audiences across all request types in one list. // All returns all supported audiences across all request types in one list.
func (a Audiences) All() (auds []string) { func (a Audiences) All() (auds []string) {
auds = a.Sign auds = a.Sign
auds = append(auds, a.Renew...)
auds = append(auds, a.Revoke...) auds = append(auds, a.Revoke...)
auds = append(auds, a.SSHSign...) auds = append(auds, a.SSHSign...)
auds = append(auds, a.SSHRevoke...) auds = append(auds, a.SSHRevoke...)
@ -70,6 +71,7 @@ func (a Audiences) All() (auds []string) {
func (a Audiences) WithFragment(fragment string) Audiences { func (a Audiences) WithFragment(fragment string) Audiences {
ret := Audiences{ ret := Audiences{
Sign: make([]string, len(a.Sign)), Sign: make([]string, len(a.Sign)),
Renew: make([]string, len(a.Renew)),
Revoke: make([]string, len(a.Revoke)), Revoke: make([]string, len(a.Revoke)),
SSHSign: make([]string, len(a.SSHSign)), SSHSign: make([]string, len(a.SSHSign)),
SSHRevoke: make([]string, len(a.SSHRevoke)), SSHRevoke: make([]string, len(a.SSHRevoke)),
@ -83,6 +85,13 @@ func (a Audiences) WithFragment(fragment string) Audiences {
ret.Sign[i] = s ret.Sign[i] = s
} }
} }
for i, s := range a.Renew {
if u, err := url.Parse(s); err == nil {
ret.Renew[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
} else {
ret.Renew[i] = s
}
}
for i, s := range a.Revoke { for i, s := range a.Revoke {
if u, err := url.Parse(s); err == nil { if u, err := url.Parse(s); err == nil {
ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String() ret.Revoke[i] = u.ResolveReference(&url.URL{Fragment: fragment}).String()
@ -210,6 +219,12 @@ type Config struct {
// GetIdentityFunc is a function that returns an identity that will be // GetIdentityFunc is a function that returns an identity that will be
// used by the provisioner to populate certificate attributes. // used by the provisioner to populate certificate attributes.
GetIdentityFunc GetIdentityFunc GetIdentityFunc GetIdentityFunc
// AuthorizeRenewFunc is a function that returns nil if a given X.509
// certificate can be renewed.
AuthorizeRenewFunc AuthorizeRenewFunc
// AuthorizeSSHRenewFunc is a function that returns nil if a given SSH
// certificate can be renewed.
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
} }
type provisioner struct { type provisioner struct {
@ -278,32 +293,6 @@ func (l *List) UnmarshalJSON(data []byte) error {
return nil return nil
} }
var sshUserRegex = regexp.MustCompile("^[a-z][-a-z0-9_]*$")
// SanitizeSSHUserPrincipal grabs an email or a string with the format
// local@domain and returns a sanitized version of the local, valid to be used
// as a user name. If the email starts with a letter between a and z, the
// resulting string will match the regular expression `^[a-z][-a-z0-9_]*$`.
func SanitizeSSHUserPrincipal(email string) string {
if i := strings.LastIndex(email, "@"); i >= 0 {
email = email[:i]
}
return strings.Map(func(r rune) rune {
switch {
case r >= 'a' && r <= 'z':
return r
case r >= '0' && r <= '9':
return r
case r == '-':
return '-'
case r == '.': // drop dots
return -1
default:
return '_'
}
}, strings.ToLower(email))
}
type base struct{} type base struct{}
// AuthorizeSign returns an unimplemented error. Provisioners should overwrite // AuthorizeSign returns an unimplemented error. Provisioners should overwrite
@ -348,66 +337,12 @@ func (b *base) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certif
return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented") return nil, nil, errs.Unauthorized("provisioner.AuthorizeSSHRekey not implemented")
} }
// Identity is the type representing an externally supplied identity that is used
// by provisioners to populate certificate fields.
type Identity struct {
Usernames []string `json:"usernames"`
Permissions `json:"permissions"`
}
// Permissions defines extra extensions and critical options to grant to an SSH certificate. // Permissions defines extra extensions and critical options to grant to an SSH certificate.
type Permissions struct { type Permissions struct {
Extensions map[string]string `json:"extensions"` Extensions map[string]string `json:"extensions"`
CriticalOptions map[string]string `json:"criticalOptions"` CriticalOptions map[string]string `json:"criticalOptions"`
} }
// GetIdentityFunc is a function that returns an identity.
type GetIdentityFunc func(ctx context.Context, p Interface, email string) (*Identity, error)
// DefaultIdentityFunc return a default identity depending on the provisioner
// type. For OIDC email is always present and the usernames might
// contain empty strings.
func DefaultIdentityFunc(ctx context.Context, p Interface, email string) (*Identity, error) {
switch k := p.(type) {
case *OIDC:
// OIDC principals would be:
// ~~1. Preferred usernames.~~ Note: Under discussion, currently disabled
// 2. Sanitized local.
// 3. Raw local (if different).
// 4. Email address.
name := SanitizeSSHUserPrincipal(email)
if !sshUserRegex.MatchString(name) {
return nil, errors.Errorf("invalid principal '%s' from email '%s'", name, email)
}
usernames := []string{name}
if i := strings.LastIndex(email, "@"); i >= 0 {
usernames = append(usernames, email[:i])
}
usernames = append(usernames, email)
return &Identity{
Usernames: SanitizeStringSlices(usernames),
}, nil
default:
return nil, errors.Errorf("provisioner type '%T' not supported by identity function", k)
}
}
// SanitizeStringSlices removes duplicated an empty strings.
func SanitizeStringSlices(original []string) []string {
output := []string{}
seen := make(map[string]struct{})
for _, entry := range original {
if entry == "" {
continue
}
if _, value := seen[entry]; !value {
seen[entry] = struct{}{}
output = append(output, entry)
}
}
return output
}
// MockProvisioner for testing // MockProvisioner for testing
type MockProvisioner struct { type MockProvisioner struct {
Mret1, Mret2, Mret3 interface{} Mret1, Mret2, Mret3 interface{}

View file

@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy" "github.com/smallstep/certificates/authority/policy"
) )
@ -12,25 +13,27 @@ import (
// SCEP provisioning flow // SCEP provisioning flow
type SCEP struct { type SCEP struct {
*base *base
ID string `json:"-"` ID string `json:"-"`
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
ForceCN bool `json:"forceCN,omitempty"` ForceCN bool `json:"forceCN,omitempty"`
ChallengePassword string `json:"challenge,omitempty"` ChallengePassword string `json:"challenge,omitempty"`
Capabilities []string `json:"capabilities,omitempty"` Capabilities []string `json:"capabilities,omitempty"`
// IncludeRoot makes the provisioner return the CA root in addition to the // IncludeRoot makes the provisioner return the CA root in addition to the
// intermediate in the GetCACerts response // intermediate in the GetCACerts response
IncludeRoot bool `json:"includeRoot,omitempty"` IncludeRoot bool `json:"includeRoot,omitempty"`
// MinimumPublicKeyLength is the minimum length for public keys in CSRs // MinimumPublicKeyLength is the minimum length for public keys in CSRs
MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"` MinimumPublicKeyLength int `json:"minimumPublicKeyLength,omitempty"`
// Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7 // Numerical identifier for the ContentEncryptionAlgorithm as defined in github.com/mozilla-services/pkcs7
// at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63 // at https://github.com/mozilla-services/pkcs7/blob/33d05740a3526e382af6395d3513e73d4e66d1cb/encrypt.go#L63
// Defaults to 0, being DES-CBC // Defaults to 0, being DES-CBC
EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"` EncryptionAlgorithmIdentifier int `json:"encryptionAlgorithmIdentifier,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
claimer *Claimer ctl *Controller
x509Policy policy.X509Policy x509Policy policy.X509Policy
secretChallengePassword string secretChallengePassword string
encryptionAlgorithm int encryptionAlgorithm int
@ -78,7 +81,7 @@ func (s *SCEP) GetOptions() *Options {
// DefaultTLSCertDuration returns the default TLS cert duration enforced by // DefaultTLSCertDuration returns the default TLS cert duration enforced by
// the provisioner. // the provisioner.
func (s *SCEP) DefaultTLSCertDuration() time.Duration { func (s *SCEP) DefaultTLSCertDuration() time.Duration {
return s.claimer.DefaultTLSCertDuration() return s.ctl.Claimer.DefaultTLSCertDuration()
} }
// Init initializes and validates the fields of a SCEP type. // Init initializes and validates the fields of a SCEP type.
@ -90,11 +93,6 @@ func (s *SCEP) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
} }
// Update claims with global ones
if s.claimer, err = NewClaimer(s.Claims, config.Claims); err != nil {
return err
}
// Mask the actual challenge value, so it won't be marshaled // Mask the actual challenge value, so it won't be marshaled
s.secretChallengePassword = s.ChallengePassword s.secretChallengePassword = s.ChallengePassword
s.ChallengePassword = "*** redacted ***" s.ChallengePassword = "*** redacted ***"
@ -120,7 +118,8 @@ func (s *SCEP) Init(config Config) (err error) {
return err return err
} }
return err s.ctl, err = NewController(s, s.Claims, config)
return
} }
// AuthorizeSign does not do any verification, because all verification is handled // AuthorizeSign does not do any verification, because all verification is handled
@ -131,10 +130,10 @@ func (s *SCEP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeSCEP, s.Name, ""), newProvisionerExtensionOption(TypeSCEP, s.Name, ""),
newForceCNOption(s.ForceCN), newForceCNOption(s.ForceCN),
profileDefaultDuration(s.claimer.DefaultTLSCertDuration()), profileDefaultDuration(s.ctl.Claimer.DefaultTLSCertDuration()),
// validators // validators
newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength), newPublicKeyMinimumLengthValidator(s.MinimumPublicKeyLength),
newValidityValidator(s.claimer.MinTLSCertDuration(), s.claimer.MaxTLSCertDuration()), newValidityValidator(s.ctl.Claimer.MinTLSCertDuration(), s.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(s.x509Policy), newX509NamePolicyValidator(s.x509Policy),
}, nil }, nil
} }

View file

@ -14,11 +14,11 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/errs"
) )
// DefaultCertValidity is the default validity for a certificate if none is specified. // DefaultCertValidity is the default validity for a certificate if none is specified.
@ -432,12 +432,12 @@ var (
stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...) stepOIDProvisioner = append(asn1.ObjectIdentifier(nil), append(stepOIDRoot, 1)...)
) )
type stepProvisionerASN1 struct { // type stepProvisionerASN1 struct {
Type int // Type int
Name []byte // Name []byte
CredentialID []byte // CredentialID []byte
KeyValuePairs []string `asn1:"optional,omitempty"` // KeyValuePairs []string `asn1:"optional,omitempty"`
} // }
type forceCNOption struct { type forceCNOption struct {
ForceCN bool ForceCN bool
@ -464,23 +464,22 @@ func (o *forceCNOption) Modify(cert *x509.Certificate, _ SignOptions) error {
} }
type provisionerExtensionOption struct { type provisionerExtensionOption struct {
Type int Extension
Name string
CredentialID string
KeyValuePairs []string
} }
func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption { func newProvisionerExtensionOption(typ Type, name, credentialID string, keyValuePairs ...string) *provisionerExtensionOption {
return &provisionerExtensionOption{ return &provisionerExtensionOption{
Type: int(typ), Extension: Extension{
Name: name, Type: typ,
CredentialID: credentialID, Name: name,
KeyValuePairs: keyValuePairs, CredentialID: credentialID,
KeyValuePairs: keyValuePairs,
},
} }
} }
func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error { func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOptions) error {
ext, err := createProvisionerExtension(o.Type, o.Name, o.CredentialID, o.KeyValuePairs...) ext, err := o.ToExtension()
if err != nil { if err != nil {
return errs.NewError(http.StatusInternalServerError, err, "error creating certificate") return errs.NewError(http.StatusInternalServerError, err, "error creating certificate")
} }
@ -494,20 +493,3 @@ func (o *provisionerExtensionOption) Modify(cert *x509.Certificate, _ SignOption
cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...) cert.ExtraExtensions = append([]pkix.Extension{ext}, cert.ExtraExtensions...)
return nil return nil
} }
func createProvisionerExtension(typ int, name, credentialID string, keyValuePairs ...string) (pkix.Extension, error) {
b, err := asn1.Marshal(stepProvisionerASN1{
Type: typ,
Name: []byte(name),
CredentialID: []byte(credentialID),
KeyValuePairs: keyValuePairs,
})
if err != nil {
return pkix.Extension{}, errors.Wrap(err, "error marshaling provisioner extension")
}
return pkix.Extension{
Id: stepOIDProvisioner,
Critical: false,
Value: b,
}, nil
}

View file

@ -636,18 +636,18 @@ func Test_newProvisionerExtension_Option(t *testing.T) {
valid: func(cert *x509.Certificate) { valid: func(cert *x509.Certificate) {
if assert.Len(t, 1, cert.ExtraExtensions) { if assert.Len(t, 1, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0] ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner) assert.Equals(t, ext.Id, StepOIDProvisioner)
} }
}, },
} }
}, },
"ok/prepend": func() test { "ok/prepend": func() test {
return test{ return test{
cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: stepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}}, cert: &x509.Certificate{ExtraExtensions: []pkix.Extension{{Id: StepOIDProvisioner, Critical: true}, {Id: []int{1, 2, 3}}}},
valid: func(cert *x509.Certificate) { valid: func(cert *x509.Certificate) {
if assert.Len(t, 3, cert.ExtraExtensions) { if assert.Len(t, 3, cert.ExtraExtensions) {
ext := cert.ExtraExtensions[0] ext := cert.ExtraExtensions[0]
assert.Equals(t, ext.Id, stepOIDProvisioner) assert.Equals(t, ext.Id, StepOIDProvisioner)
assert.False(t, ext.Critical) assert.False(t, ext.Critical)
} }
}, },

View file

@ -685,7 +685,7 @@ func Test_sshCertDefaultValidator_Valid(t *testing.T) {
func Test_sshCertValidityValidator(t *testing.T) { func Test_sshCertValidityValidator(t *testing.T) {
p, err := generateX5C(nil) p, err := generateX5C(nil)
assert.FatalError(t, err) assert.FatalError(t, err)
v := sshCertValidityValidator{p.claimer} v := sshCertValidityValidator{p.ctl.Claimer}
n := now() n := now()
tests := []struct { tests := []struct {
name string name string
@ -806,7 +806,7 @@ func Test_sshValidityModifier(t *testing.T) {
tests := map[string]func() test{ tests := map[string]func() test{
"fail/type-not-set": func() test { "fail/type-not-set": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
ValidBefore: uint64(n.Add(8 * time.Hour).Unix()), ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
@ -816,7 +816,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/type-not-recognized": func() test { "fail/type-not-recognized": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(6 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 4, CertType: 4,
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
@ -827,7 +827,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/requested-validAfter-after-limit": func() test { "fail/requested-validAfter-after-limit": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: uint64(n.Add(2 * time.Hour).Unix()), ValidAfter: uint64(n.Add(2 * time.Hour).Unix()),
@ -838,7 +838,7 @@ func Test_sshValidityModifier(t *testing.T) {
}, },
"fail/requested-validBefore-after-limit": func() test { "fail/requested-validBefore-after-limit": func() test {
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(1 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: uint64(n.Unix()), ValidAfter: uint64(n.Unix()),
@ -850,7 +850,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/no-limit": func() test { "ok/no-limit": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
}, },
@ -863,7 +863,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/defaults": func() test { "ok/defaults": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
}, },
@ -876,7 +876,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/valid-requested-validBefore": func() test { "ok/valid-requested-validBefore": func() test {
va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix()) va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,
@ -891,7 +891,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/empty-requested-validBefore-limit-after-default": func() test { "ok/empty-requested-validBefore-limit-after-default": func() test {
va := uint64(n.Unix()) va := uint64(n.Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(24 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,
@ -905,7 +905,7 @@ func Test_sshValidityModifier(t *testing.T) {
"ok/empty-requested-validBefore-limit-before-default": func() test { "ok/empty-requested-validBefore-limit-before-default": func() test {
va := uint64(n.Unix()) va := uint64(n.Unix())
return test{ return test{
svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)}, svm: &sshLimitDuration{Claimer: p.ctl.Claimer, NotAfter: n.Add(3 * time.Hour)},
cert: &ssh.Certificate{ cert: &ssh.Certificate{
CertType: 1, CertType: 1,
ValidAfter: va, ValidAfter: va,

View file

@ -29,8 +29,7 @@ type SSHPOP struct {
Type string `json:"type"` Type string `json:"type"`
Name string `json:"name"` Name string `json:"name"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences
sshPubKeys *SSHKeys sshPubKeys *SSHKeys
} }
@ -83,7 +82,7 @@ func (p *SSHPOP) GetEncryptedKey() (string, string, bool) {
} }
// Init initializes and validates the fields of a SSHPOP type. // Init initializes and validates the fields of a SSHPOP type.
func (p *SSHPOP) Init(config Config) error { func (p *SSHPOP) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -93,17 +92,15 @@ func (p *SSHPOP) Init(config Config) error {
return errors.New("provisioner public SSH validation keys cannot be empty") return errors.New("provisioner public SSH validation keys cannot be empty")
} }
// Update claims with global ones
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// TODO(hs): initialize the policy engine and add it as an SSH cert validator // TODO(hs): initialize the policy engine and add it as an SSH cert validator
p.audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.sshPubKeys = config.SSHKeys p.sshPubKeys = config.SSHKeys
return nil
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
// Update claims with global ones
p.ctl, err = NewController(p, p.Claims, config)
return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -111,7 +108,7 @@ func (p *SSHPOP) Init(config Config) error {
// e.g. a Sign request will auth/validate different fields than a Revoke request. // e.g. a Sign request will auth/validate different fields than a Revoke request.
// //
// Checking for certificate revocation has been moved to the authority package. // Checking for certificate revocation has been moved to the authority package.
func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayload, error) { func (p *SSHPOP) authorizeToken(token string, audiences []string, checkValidity bool) (*sshPOPPayload, error) {
sshCert, jwt, err := ExtractSSHPOPCert(token) sshCert, jwt, err := ExtractSSHPOPCert(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, return nil, errs.Wrap(http.StatusUnauthorized, err,
@ -119,13 +116,18 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
} }
// Check validity period of the certificate. // Check validity period of the certificate.
n := time.Now() //
if sshCert.ValidAfter != 0 && time.Unix(int64(sshCert.ValidAfter), 0).After(n) { // Controller.AuthorizeSSHRenew will validate this on the renewal flow.
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future") if checkValidity {
} unixNow := time.Now().Unix()
if sshCert.ValidBefore != 0 && time.Unix(int64(sshCert.ValidBefore), 0).Before(n) { if after := int64(sshCert.ValidAfter); after < 0 || unixNow < int64(sshCert.ValidAfter) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past") return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validAfter is in the future")
}
if before := int64(sshCert.ValidBefore); sshCert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) {
return nil, errs.Unauthorized("sshpop.authorizeToken; sshpop certificate validBefore is in the past")
}
} }
sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey) sshCryptoPubKey, ok := sshCert.Key.(ssh.CryptoPublicKey)
if !ok { if !ok {
return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey") return nil, errs.InternalServer("sshpop.authorizeToken; sshpop public key could not be cast to ssh CryptoPublicKey")
@ -188,7 +190,7 @@ func (p *SSHPOP) authorizeToken(token string, audiences []string) (*sshPOPPayloa
// AuthorizeSSHRevoke validates the authorization token and extracts/validates // AuthorizeSSHRevoke validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
claims, err := p.authorizeToken(token, p.audiences.SSHRevoke) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke, true)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRevoke")
} }
@ -201,22 +203,20 @@ func (p *SSHPOP) AuthorizeSSHRevoke(ctx context.Context, token string) error {
// AuthorizeSSHRenew validates the authorization token and extracts/validates // AuthorizeSSHRenew validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { func (p *SSHPOP) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRenew) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRenew, false)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew") return nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRenew")
} }
if claims.sshCert.CertType != ssh.HostCert { if claims.sshCert.CertType != ssh.HostCert {
return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate") return nil, errs.BadRequest("sshpop certificate must be a host ssh certificate")
} }
return claims.sshCert, p.ctl.AuthorizeSSHRenew(ctx, claims.sshCert)
return claims.sshCert, nil
} }
// AuthorizeSSHRekey validates the authorization token and extracts/validates // AuthorizeSSHRekey validates the authorization token and extracts/validates
// the SSH certificate from the ssh-pop header. // the SSH certificate from the ssh-pop header.
func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) { func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Certificate, []SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.SSHRekey) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHRekey, true)
if err != nil { if err != nil {
return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey") return nil, nil, errs.Wrap(http.StatusInternalServerError, err, "sshpop.AuthorizeSSHRekey")
} }
@ -227,11 +227,10 @@ func (p *SSHPOP) AuthorizeSSHRekey(ctx context.Context, token string) (*ssh.Cert
// Validate public key // Validate public key
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
}, nil }, nil
} }
// ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate // ExtractSSHPOPCert parses a JWT and extracts and loads the SSH Certificate

View file

@ -38,6 +38,7 @@ func TestSSHPOP_Getters(t *testing.T) {
} }
func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) { func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, *jose.JSONWebKey, error) {
now := time.Now()
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "foo", 0)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -46,6 +47,12 @@ func createSSHCert(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate,
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if cert.ValidAfter == 0 {
cert.ValidAfter = uint64(now.Unix())
}
if cert.ValidBefore == 0 {
cert.ValidBefore = uint64(now.Add(time.Hour).Unix())
}
if err := cert.SignCert(rand.Reader, signer); err != nil { if err := cert.SignCert(rand.Reader, signer); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -207,7 +214,7 @@ func TestSSHPOP_authorizeToken(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil { if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign, true); err != nil {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tc.code) assert.Equals(t, sc.StatusCode(), tc.code)
@ -455,7 +462,7 @@ func TestSSHPOP_AuthorizeSSHRekey(t *testing.T) {
case *sshDefaultPublicKeyValidator: case *sshDefaultPublicKeyValidator:
case *sshCertDefaultValidator: case *sshCertDefaultValidator:
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
default: default:
assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, errors.Errorf("unexpected sign option of type %T", v))
} }

View file

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDeTCCAx+gAwIBAgIRAOTItW2pYuSU+PkmLW090iUwCgYIKoZIzj0EAwIwJDEi
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjUy
MjBaFw0yMjAzMTIyMjUzMjBaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAE/9vvOZ1Zzysnf3VeGyotMJEMZdAborB36Ah5QL/3yQNMRWIc
pv9Dwx19pHw7SquVE8jIaPPJSjaeWnfMPDYDxaOCAbcwggGzMA4GA1UdDwEB/wQE
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
ADAdBgNVHQ4EFgQUkJUg6AsqWlqTZt6BHidRMwh1vKYwHwYDVR0jBBgwFoAUDpTg
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
LXNlcnZlci1nMy5jcmwwFwYMKwYBBAGCpGTGKEABBAdmb29vYmFyMAoGCCqGSM49
BAMCA0gAMEUCIQCWYqOuk4bLkVVeHvo3P8TlJJ3fw6ijDDLstvdrQqAl5wIgEjSY
wVcR649Oc8PJGh/43Kpx0+4OTYPQrD/JqphVF7g=
-----END CERTIFICATE-----

View file

@ -0,0 +1,22 @@
-----BEGIN CERTIFICATE-----
MIIDujCCA2GgAwIBAgIRAM5celDKTTqAGycljO7FZdEwCgYIKoZIzj0EAwIwJDEi
MCAGA1UEAxMZU21hbGxzdGVwIEludGVybWVkaWF0ZSBDQTAeFw0yMjAzMTEyMjQx
MDRaFw0yMjAzMTIyMjQyMDRaMIGcMQswCQYDVQQGEwJDSDETMBEGA1UECBMKQ2Fs
aWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEYMBYGA1UECRMPMSBUaGUg
U3RyZWV0IFN0MRMwEQYDVQQKDAo8bm8gdmFsdWU+MRYwFAYDVQQLEw1TbWFsbHN0
ZXAgRW5nMRkwFwYDVQQDDBB0ZXN0QGV4YW1wbGUuY29tMFkwEwYHKoZIzj0CAQYI
KoZIzj0DAQcDQgAEkXffZYlSJRMxJrZHmUpEMC4jQYCkF86mLJY0iLZ8k00N/xF0
4rAGwzTU/l9tfRpNl+z/XfMMWPXS0Q8NU/o4S6OCAfkwggH1MA4GA1UdDwEB/wQE
AwIHgDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIw
ADAdBgNVHQ4EFgQUL3sSlYW8Tf2l2P+gFTdn5wsUjfgwHwYDVR0jBBgwFoAUDpTg
d3VFCn6e71wXcwbDCURBomUwgZoGCCsGAQUFBwEBBIGNMIGKMBcGCCsGAQUFBzAB
hgtodHRwczovL2ZvbzBvBggrBgEFBQcwAoZjaHR0cHM6Ly9jYS5zbWFsbHN0ZXAu
Y29tOjkwMDAvcm9vdC9hNzhhODUwMDI1YzBjMjM0Mzg1ZWRhMjNkNzE5Mjk2NGNh
NTZhYTlkNzI3ZjUzNTY1M2IwYWZiODFjMWUwNTU5MBsGA1UdEQQUMBKBEHRlc3RA
ZXhhbXBsZS5jb20wIAYDVR0gBBkwFzALBglghkgBhv1sAQEwCAYGZ4EMAQICMD8G
A1UdHwQ4MDYwNKAyoDCGLmh0dHA6Ly9jcmwzLmRpZ2ljZXJ0LmNvbS9zaGEyLWV2
LXNlcnZlci1nMy5jcmwwWQYMKwYBBAGCpGTGKEABBEkwRwIBAQQVbWFyaWFub0Bz
bWFsbHN0ZXAuY29tBCtudmduUjh3U3pwVWxydF90QzNtdnJod2hCeDlZN1QxV0xf
SmpjRlZXWUJRMAoGCCqGSM49BAMCA0cAMEQCIE6umrhSbeQWWVK5cWBvXj5c0cGB
bUF0rNw/dsaCaWcwAiAKSkmjhsC63DVPXPCNUki90YgVovO69foO1ZaB43lx5w==
-----END CERTIFICATE-----

View file

@ -24,20 +24,22 @@ import (
) )
var ( var (
defaultDisableRenewal = false defaultDisableRenewal = false
defaultEnableSSHCA = true defaultAllowRenewAfterExpiry = false
globalProvisionerClaims = Claims{ defaultEnableSSHCA = true
MinTLSDur: &Duration{5 * time.Minute}, globalProvisionerClaims = Claims{
MaxTLSDur: &Duration{24 * time.Hour}, MinTLSDur: &Duration{5 * time.Minute},
DefaultTLSDur: &Duration{24 * time.Hour}, MaxTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal, DefaultTLSDur: &Duration{24 * time.Hour},
MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs MinUserSSHDur: &Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &Duration{Duration: 24 * time.Hour}, MaxUserSSHDur: &Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour}, DefaultUserSSHDur: &Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs MinHostSSHDur: &Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, MaxHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour}, DefaultHostSSHDur: &Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &defaultEnableSSHCA, EnableSSHCA: &defaultEnableSSHCA,
DisableRenewal: &defaultDisableRenewal,
AllowRenewAfterExpiry: &defaultAllowRenewAfterExpiry,
} }
testAudiences = Audiences{ testAudiences = Audiences{
Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"}, Sign: []string{"https://ca.smallstep.com/1.0/sign", "https://ca.smallstep.com/sign"},
@ -172,19 +174,18 @@ func generateJWK() (*JWK, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil { p := &JWK{
return nil, err
}
return &JWK{
Name: name, Name: name,
Type: "JWK", Type: "JWK",
Key: &public, Key: &public,
EncryptedKey: encrypted, EncryptedKey: encrypted,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, }
claimer: claimer, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) { func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
@ -205,23 +206,21 @@ func generateK8sSA(inputPubKey interface{}) (*K8sSA, error) {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
pubKeys := []interface{}{fooPub, barPub} pubKeys := []interface{}{fooPub, barPub}
if inputPubKey != nil { if inputPubKey != nil {
pubKeys = append(pubKeys, inputPubKey) pubKeys = append(pubKeys, inputPubKey)
} }
return &K8sSA{ p := &K8sSA{
Name: K8sSAName, Name: K8sSAName,
Type: "K8sSA", Type: "K8sSA",
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, pubKeys: pubKeys,
claimer: claimer, }
pubKeys: pubKeys, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateSSHPOP() (*SSHPOP, error) { func generateSSHPOP() (*SSHPOP, error) {
@ -229,11 +228,6 @@ func generateSSHPOP() (*SSHPOP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub") userB, err := os.ReadFile("./testdata/certs/ssh_user_ca_key.pub")
if err != nil { if err != nil {
return nil, err return nil, err
@ -251,17 +245,19 @@ func generateSSHPOP() (*SSHPOP, error) {
return nil, err return nil, err
} }
return &SSHPOP{ p := &SSHPOP{
Name: name, Name: name,
Type: "SSHPOP", Type: "SSHPOP",
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences,
claimer: claimer,
sshPubKeys: &SSHKeys{ sshPubKeys: &SSHKeys{
UserKeys: []ssh.PublicKey{userKey}, UserKeys: []ssh.PublicKey{userKey},
HostKeys: []ssh.PublicKey{hostKey}, HostKeys: []ssh.PublicKey{hostKey},
}, },
}, nil }
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateX5C(root []byte) (*X5C, error) { func generateX5C(root []byte) (*X5C, error) {
@ -283,11 +279,6 @@ M46l92gdOozT
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
rootPool := x509.NewCertPool() rootPool := x509.NewCertPool()
var ( var (
@ -305,15 +296,17 @@ M46l92gdOozT
} }
rootPool.AddCert(cert) rootPool.AddCert(cert)
} }
return &X5C{ p := &X5C{
Name: name, Name: name,
Type: "X5C", Type: "X5C",
Roots: root, Roots: root,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
audiences: testAudiences, rootPool: rootPool,
claimer: claimer, }
rootPool: rootPool, p.ctl, err = NewController(p, p.Claims, Config{
}, nil Audiences: testAudiences,
})
return p, err
} }
func generateOIDC() (*OIDC, error) { func generateOIDC() (*OIDC, error) {
@ -333,11 +326,7 @@ func generateOIDC() (*OIDC, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims) p := &OIDC{
if err != nil {
return nil, err
}
return &OIDC{
Name: name, Name: name,
Type: "OIDC", Type: "OIDC",
ClientID: clientID, ClientID: clientID,
@ -351,8 +340,11 @@ func generateOIDC() (*OIDC, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
claimer: claimer, }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateGCP() (*GCP, error) { func generateGCP() (*GCP, error) {
@ -368,23 +360,21 @@ func generateGCP() (*GCP, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims) p := &GCP{
if err != nil {
return nil, err
}
return &GCP{
Type: "GCP", Type: "GCP",
Name: name, Name: name,
ServiceAccounts: []string{serviceAccount}, ServiceAccounts: []string{serviceAccount},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
claimer: claimer,
config: newGCPConfig(), config: newGCPConfig(),
keyStore: &keyStore{ keyStore: &keyStore{
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
audiences: testAudiences.WithFragment("gcp/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("gcp/" + name),
})
return p, err
} }
func generateAWS() (*AWS, error) { func generateAWS() (*AWS, error) {
@ -396,10 +386,6 @@ func generateAWS() (*AWS, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate)) block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" { if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate") return nil, errors.New("error decoding AWS certificate")
@ -408,13 +394,12 @@ func generateAWS() (*AWS, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate") return nil, errors.Wrap(err, "error parsing AWS certificate")
} }
return &AWS{ p := &AWS{
Type: "AWS", Type: "AWS",
Name: name, Name: name,
Accounts: []string{accountID}, Accounts: []string{accountID},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v2", "v1"}, IMDSVersions: []string{"v2", "v1"},
claimer: claimer,
config: &awsConfig{ config: &awsConfig{
identityURL: awsIdentityURL, identityURL: awsIdentityURL,
signatureURL: awsSignatureURL, signatureURL: awsSignatureURL,
@ -423,8 +408,11 @@ func generateAWS() (*AWS, error) {
certificates: []*x509.Certificate{cert}, certificates: []*x509.Certificate{cert},
signatureAlgorithm: awsSignatureAlgorithm, signatureAlgorithm: awsSignatureAlgorithm,
}, },
audiences: testAudiences.WithFragment("aws/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
return p, err
} }
func generateAWSWithServer() (*AWS, *httptest.Server, error) { func generateAWSWithServer() (*AWS, *httptest.Server, error) {
@ -505,10 +493,6 @@ func generateAWSV1Only() (*AWS, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
block, _ := pem.Decode([]byte(awsTestCertificate)) block, _ := pem.Decode([]byte(awsTestCertificate))
if block == nil || block.Type != "CERTIFICATE" { if block == nil || block.Type != "CERTIFICATE" {
return nil, errors.New("error decoding AWS certificate") return nil, errors.New("error decoding AWS certificate")
@ -517,13 +501,12 @@ func generateAWSV1Only() (*AWS, error) {
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error parsing AWS certificate") return nil, errors.Wrap(err, "error parsing AWS certificate")
} }
return &AWS{ p := &AWS{
Type: "AWS", Type: "AWS",
Name: name, Name: name,
Accounts: []string{accountID}, Accounts: []string{accountID},
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
IMDSVersions: []string{"v1"}, IMDSVersions: []string{"v1"},
claimer: claimer,
config: &awsConfig{ config: &awsConfig{
identityURL: awsIdentityURL, identityURL: awsIdentityURL,
signatureURL: awsSignatureURL, signatureURL: awsSignatureURL,
@ -532,8 +515,11 @@ func generateAWSV1Only() (*AWS, error) {
certificates: []*x509.Certificate{cert}, certificates: []*x509.Certificate{cert},
signatureAlgorithm: awsSignatureAlgorithm, signatureAlgorithm: awsSignatureAlgorithm,
}, },
audiences: testAudiences.WithFragment("aws/" + name), }
}, nil p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences.WithFragment("aws/" + name),
})
return p, err
} }
func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) { func generateAWSWithServerV1Only() (*AWS, *httptest.Server, error) {
@ -600,21 +586,16 @@ func generateAzure() (*Azure, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
claimer, err := NewClaimer(nil, globalProvisionerClaims)
if err != nil {
return nil, err
}
jwk, err := generateJSONWebKey() jwk, err := generateJSONWebKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Azure{ p := &Azure{
Type: "Azure", Type: "Azure",
Name: name, Name: name,
TenantID: tenantID, TenantID: tenantID,
Audience: azureDefaultAudience, Audience: azureDefaultAudience,
Claims: &globalProvisionerClaims, Claims: &globalProvisionerClaims,
claimer: claimer,
config: newAzureConfig(tenantID), config: newAzureConfig(tenantID),
oidcConfig: openIDConfiguration{ oidcConfig: openIDConfiguration{
Issuer: "https://sts.windows.net/" + tenantID + "/", Issuer: "https://sts.windows.net/" + tenantID + "/",
@ -624,7 +605,11 @@ func generateAzure() (*Azure, error) {
keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}}, keySet: jose.JSONWebKeySet{Keys: []jose.JSONWebKey{*jwk}},
expiry: time.Now().Add(24 * time.Hour), expiry: time.Now().Add(24 * time.Hour),
}, },
}, nil }
p.ctl, err = NewController(p, p.Claims, Config{
Audiences: testAudiences,
})
return p, err
} }
func generateAzureWithServer() (*Azure, *httptest.Server, error) { func generateAzureWithServer() (*Azure, *httptest.Server, error) {
@ -671,7 +656,7 @@ func generateAzureWithServer() (*Azure, *httptest.Server, error) {
w.Header().Add("Cache-Control", "max-age=5") w.Header().Add("Cache-Control", "max-age=5")
writeJSON(w, getPublic(az.keyStore.keySet)) writeJSON(w, getPublic(az.keyStore.keySet))
case "/metadata/identity/oauth2/token": case "/metadata/identity/oauth2/token":
tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", time.Now(), &az.keyStore.keySet.Keys[0]) tok, err := generateAzureToken("subject", issuer, "https://management.azure.com/", az.TenantID, "subscriptionID", "resourceGroup", "virtualMachine", "vm", time.Now(), &az.keyStore.keySet.Keys[0])
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
} else { } else {
@ -1009,7 +994,7 @@ func generateAWSToken(p *AWS, sub, iss, aud, accountID, instanceID, privateIP, r
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }
func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, virtualMachine string, iat time.Time, jwk *jose.JSONWebKey) (string, error) { func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup, resourceName, resourceType string, iat time.Time, jwk *jose.JSONWebKey) (string, error) {
sig, err := jose.NewSigner( sig, err := jose.NewSigner(
jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key}, jose.SigningKey{Algorithm: jose.ES256, Key: jwk.Key},
new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID), new(jose.SignerOptions).WithType("JWT").WithHeader("kid", jwk.KeyID),
@ -1017,6 +1002,12 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
if err != nil { if err != nil {
return "", err return "", err
} }
var xmsMirID string
if resourceType == "vm" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, resourceName)
} else if resourceType == "uai" {
xmsMirID = fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ManagedIdentity/userAssignedIdentities/%s", subscriptionID, resourceGroup, resourceName)
}
claims := azurePayload{ claims := azurePayload{
Claims: jose.Claims{ Claims: jose.Claims{
@ -1034,7 +1025,7 @@ func generateAzureToken(sub, iss, aud, tenantID, subscriptionID, resourceGroup,
ObjectID: "the-oid", ObjectID: "the-oid",
TenantID: tenantID, TenantID: tenantID,
Version: "the-version", Version: "the-version",
XMSMirID: fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Compute/virtualMachines/%s", subscriptionID, resourceGroup, virtualMachine), XMSMirID: xmsMirID,
} }
return jose.Signed(sig).Claims(claims).CompactSerialize() return jose.Signed(sig).Claims(claims).CompactSerialize()
} }

View file

@ -33,8 +33,7 @@ type X5C struct {
Roots []byte `json:"roots"` Roots []byte `json:"roots"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
claimer *Claimer ctl *Controller
audiences Audiences
rootPool *x509.CertPool rootPool *x509.CertPool
x509Policy policy.X509Policy x509Policy policy.X509Policy
sshHostPolicy policy.HostPolicy sshHostPolicy policy.HostPolicy
@ -90,7 +89,7 @@ func (p *X5C) GetEncryptedKey() (string, string, bool) {
} }
// Init initializes and validates the fields of a X5C type. // Init initializes and validates the fields of a X5C type.
func (p *X5C) Init(config Config) error { func (p *X5C) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -105,6 +104,7 @@ func (p *X5C) Init(config Config) error {
var ( var (
block *pem.Block block *pem.Block
rest = p.Roots rest = p.Roots
count int
) )
for rest != nil { for rest != nil {
block, rest = pem.Decode(rest) block, rest = pem.Decode(rest)
@ -115,20 +115,15 @@ func (p *X5C) Init(config Config) error {
if err != nil { if err != nil {
return errors.Wrap(err, "error parsing x509 certificate from PEM block") return errors.Wrap(err, "error parsing x509 certificate from PEM block")
} }
count++
p.rootPool.AddCert(cert) p.rootPool.AddCert(cert)
} }
// Verify that at least one root was found. // Verify that at least one root was found.
if len(p.rootPool.Subjects()) == 0 { if count == 0 {
return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName()) return errors.Errorf("no x509 certificates found in roots attribute for provisioner '%s'", p.GetName())
} }
// Update claims with global ones
var err error
if p.claimer, err = NewClaimer(p.Claims, config.Claims); err != nil {
return err
}
// Initialize the x509 allow/deny policy engine // Initialize the x509 allow/deny policy engine
if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil { if p.x509Policy, err = policy.NewX509PolicyEngine(p.Options.GetX509Options()); err != nil {
return err return err
@ -144,8 +139,9 @@ func (p *X5C) Init(config Config) error {
return err return err
} }
p.audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
return nil p.ctl, err = NewController(p, p.Claims, config)
return
} }
// authorizeToken performs common jwt authorization actions and returns the // authorizeToken performs common jwt authorization actions and returns the
@ -208,13 +204,13 @@ func (p *X5C) authorizeToken(token string, audiences []string) (*x5cPayload, err
// AuthorizeRevoke returns an error if the provisioner does not have rights to // AuthorizeRevoke returns an error if the provisioner does not have rights to
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error { func (p *X5C) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeRevoke")
} }
// AuthorizeSign validates the given token. // AuthorizeSign validates the given token.
func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
claims, err := p.authorizeToken(token, p.audiences.Sign) claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign")
} }
@ -246,32 +242,31 @@ func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeX5C, p.Name, ""), newProvisionerExtensionOption(TypeX5C, p.Name, ""),
profileLimitDuration{p.claimer.DefaultTLSCertDuration(), profileLimitDuration{
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter}, p.ctl.Claimer.DefaultTLSCertDuration(),
claims.chains[0][0].NotBefore, claims.chains[0][0].NotAfter,
},
// validators // validators
commonNameValidator(claims.Subject), commonNameValidator(claims.Subject),
defaultSANsValidator(claims.SANs), defaultSANsValidator(claims.SANs),
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.claimer.MinTLSCertDuration(), p.claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.x509Policy), newX509NamePolicyValidator(p.x509Policy),
}, nil }, nil
} }
// AuthorizeRenew returns an error if the renewal is disabled. // AuthorizeRenew returns an error if the renewal is disabled.
func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *X5C) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
if p.claimer.IsDisableRenewal() { return p.ctl.AuthorizeRenew(ctx, cert)
return errs.Unauthorized("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName())
}
return nil
} }
// AuthorizeSSHSign returns the list of SignOption for a SignSSH request. // AuthorizeSSHSign returns the list of SignOption for a SignSSH request.
func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
if !p.claimer.IsSSHCAEnabled() { if !p.ctl.Claimer.IsSSHCAEnabled() {
return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName()) return nil, errs.Unauthorized("x5c.AuthorizeSSHSign; sshCA is disabled for x5c provisioner '%s'", p.GetName())
} }
claims, err := p.authorizeToken(token, p.audiences.SSHSign) claims, err := p.authorizeToken(token, p.ctl.Audiences.SSHSign)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign") return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSSHSign")
} }
@ -334,11 +329,11 @@ func (p *X5C) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
return append(signOptions, return append(signOptions,
// Checks the validity bounds, and set the validity if has not been set. // Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.claimer, claims.chains[0][0].NotAfter}, &sshLimitDuration{p.ctl.Claimer, claims.chains[0][0].NotAfter},
// Validate public key. // Validate public key.
&sshDefaultPublicKeyValidator{}, &sshDefaultPublicKeyValidator{},
// Validate the validity period. // Validate the validity period.
&sshCertValidityValidator{p.claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed // Ensure that all principal names are allowed

View file

@ -2,6 +2,7 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509"
"net/http" "net/http"
"testing" "testing"
"time" "time"
@ -69,7 +70,7 @@ func TestX5C_Init(t *testing.T) {
}, },
"fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest { "fail/no-valid-root-certs": func(t *testing.T) ProvisionerValidateTest {
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo"), audiences: testAudiences}, p: &X5C{Name: "foo", Type: "bar", Roots: []byte("foo")},
err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"), err: errors.Errorf("no x509 certificates found in roots attribute for provisioner 'foo'"),
} }
}, },
@ -117,6 +118,8 @@ M46l92gdOozT
return ProvisionerValidateTest{ return ProvisionerValidateTest{
p: p, p: p,
extraValid: func(p *X5C) error { extraValid: func(p *X5C) error {
// nolint:staticcheck // We don't have a different way to
// check the number of certificates in the pool.
numCerts := len(p.rootPool.Subjects()) numCerts := len(p.rootPool.Subjects())
if numCerts != 2 { if numCerts != 2 {
return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts) return errors.Errorf("unexpected number of certs: want 2, but got %d", numCerts)
@ -141,7 +144,7 @@ M46l92gdOozT
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, tc.p.audiences, config.Audiences.WithFragment(tc.p.GetID())) assert.Equals(t, *tc.p.ctl.Audiences, config.Audiences.WithFragment(tc.p.GetID()))
if tc.extraValid != nil { if tc.extraValid != nil {
assert.Nil(t, tc.extraValid(tc.p)) assert.Nil(t, tc.extraValid(tc.p))
} }
@ -468,13 +471,13 @@ func TestX5C_AuthorizeSign(t *testing.T) {
switch v := o.(type) { switch v := o.(type) {
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, int(TypeX5C)) assert.Equals(t, v.Type, TypeX5C)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
assert.Equals(t, v.CredentialID, "") assert.Equals(t, v.CredentialID, "")
assert.Len(t, 0, v.KeyValuePairs) assert.Len(t, 0, v.KeyValuePairs)
case profileLimitDuration: case profileLimitDuration:
assert.Equals(t, v.def, tc.p.claimer.DefaultTLSCertDuration()) assert.Equals(t, v.def, tc.p.ctl.Claimer.DefaultTLSCertDuration())
claims, err := tc.p.authorizeToken(tc.token, tc.p.audiences.Sign) claims, err := tc.p.authorizeToken(tc.token, tc.p.ctl.Audiences.Sign)
assert.FatalError(t, err) assert.FatalError(t, err)
assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter) assert.Equals(t, v.notAfter, claims.chains[0][0].NotAfter)
case commonNameValidator: case commonNameValidator:
@ -483,8 +486,8 @@ func TestX5C_AuthorizeSign(t *testing.T) {
case defaultSANsValidator: case defaultSANsValidator:
assert.Equals(t, []string(v), tc.sans) assert.Equals(t, []string(v), tc.sans)
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator: case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine) assert.Equals(t, nil, v.policyEngine)
default: default:
@ -552,6 +555,7 @@ func TestX5C_AuthorizeRevoke(t *testing.T) {
} }
func TestX5C_AuthorizeRenew(t *testing.T) { func TestX5C_AuthorizeRenew(t *testing.T) {
now := time.Now().Truncate(time.Second)
type test struct { type test struct {
p *X5C p *X5C
code int code int
@ -564,12 +568,12 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
// disable renewal // disable renewal
disable := true disable := true
p.Claims = &Claims{DisableRenewal: &disable} p.Claims = &Claims{DisableRenewal: &disable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
err: errors.Errorf("x5c.AuthorizeRenew; renew is disabled for x5c provisioner '%s'", p.GetName()), err: errors.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
@ -583,7 +587,10 @@ func TestX5C_AuthorizeRenew(t *testing.T) {
for name, tt := range tests { for name, tt := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := tt(t) tc := tt(t)
if err := tc.p.AuthorizeRenew(context.Background(), nil); err != nil { if err := tc.p.AuthorizeRenew(context.Background(), &x509.Certificate{
NotBefore: now,
NotAfter: now.Add(time.Hour),
}); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
sc, ok := err.(errs.StatusCoder) sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface") assert.Fatal(t, ok, "error does not implement StatusCoder interface")
@ -619,7 +626,7 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
// disable sshCA // disable sshCA
enable := false enable := false
p.Claims = &Claims{EnableSSHCA: &enable} p.Claims = &Claims{EnableSSHCA: &enable}
p.claimer, err = NewClaimer(p.Claims, globalProvisionerClaims) p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
p: p, p: p,
@ -775,10 +782,10 @@ func TestX5C_AuthorizeSSHSign(t *testing.T) {
case sshCertDefaultsModifier: case sshCertDefaultsModifier:
assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert}) assert.Equals(t, SignSSHOptions(v), SignSSHOptions{CertType: SSHUserCert})
case *sshLimitDuration: case *sshLimitDuration:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter) assert.Equals(t, v.NotAfter, x5cCerts[0].NotAfter)
case *sshCertValidityValidator: case *sshCertValidityValidator:
assert.Equals(t, v.Claimer, tc.p.claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshNamePolicyValidator: case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine) assert.Equals(t, nil, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine) assert.Equals(t, nil, v.hostPolicyEngine)

View file

@ -109,7 +109,9 @@ func (a *Authority) generateProvisionerConfig(ctx context.Context) (provisioner.
UserKeys: sshKeys.UserKeys, UserKeys: sshKeys.UserKeys,
HostKeys: sshKeys.HostKeys, HostKeys: sshKeys.HostKeys,
}, },
GetIdentityFunc: a.getIdentityFunc, GetIdentityFunc: a.getIdentityFunc,
AuthorizeRenewFunc: a.authorizeRenewFunc,
AuthorizeSSHRenewFunc: a.authorizeSSHRenewFunc,
}, nil }, nil
} }
@ -488,7 +490,8 @@ func claimsToCertificates(c *linkedca.Claims) (*provisioner.Claims, error) {
} }
pc := &provisioner.Claims{ pc := &provisioner.Claims{
DisableRenewal: &c.DisableRenewal, DisableRenewal: &c.DisableRenewal,
AllowRenewAfterExpiry: &c.AllowRenewAfterExpiry,
} }
var err error var err error
@ -526,12 +529,18 @@ func claimsToLinkedca(c *provisioner.Claims) *linkedca.Claims {
} }
disableRenewal := config.DefaultDisableRenewal disableRenewal := config.DefaultDisableRenewal
allowRenewAfterExpiry := config.DefaultAllowRenewAfterExpiry
if c.DisableRenewal != nil { if c.DisableRenewal != nil {
disableRenewal = *c.DisableRenewal disableRenewal = *c.DisableRenewal
} }
if c.AllowRenewAfterExpiry != nil {
allowRenewAfterExpiry = *c.AllowRenewAfterExpiry
}
lc := &linkedca.Claims{ lc := &linkedca.Claims{
DisableRenewal: disableRenewal, DisableRenewal: disableRenewal,
AllowRenewAfterExpiry: allowRenewAfterExpiry,
} }
if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil { if c.DefaultTLSDur != nil || c.MinTLSDur != nil || c.MaxTLSDur != nil {

View file

@ -757,7 +757,7 @@ func TestAuthority_Renew(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7) nb1 := now.Add(-time.Minute * 7)
na1 := now na1 := now.Add(time.Hour)
so := &provisioner.SignOptions{ so := &provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nb1), NotBefore: provisioner.NewTimeDuration(nb1),
NotAfter: provisioner.NewTimeDuration(na1), NotAfter: provisioner.NewTimeDuration(na1),
@ -798,7 +798,20 @@ func TestAuthority_Renew(t *testing.T) {
"fail/unauthorized": func() (*renewTest, error) { "fail/unauthorized": func() (*renewTest, error) {
return &renewTest{ return &renewTest{
cert: certNoRenew, cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized,
}, nil
},
"fail/WithAuthorizeRenewFunc": func() (*renewTest, error) {
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
return errs.Unauthorized("not authorized")
}))
aa.x509CAService = a.x509CAService
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
return &renewTest{
auth: aa,
cert: cert,
err: errors.New("authority.Rekey: authority.authorizeRenew: not authorized"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
}, nil }, nil
}, },
@ -820,6 +833,17 @@ func TestAuthority_Renew(t *testing.T) {
cert: cert, cert: cert,
}, nil }, nil
}, },
"ok/WithAuthorizeRenewFunc": func() (*renewTest, error) {
aa := testAuthority(t, WithAuthorizeRenewFunc(func(ctx context.Context, p *provisioner.Controller, cert *x509.Certificate) error {
return nil
}))
aa.x509CAService = a.x509CAService
aa.config.AuthorityConfig.Template = a.config.AuthorityConfig.Template
return &renewTest{
auth: aa,
cert: cert,
}, nil
},
} }
for name, genTestCase := range tests { for name, genTestCase := range tests {
@ -856,7 +880,7 @@ func TestAuthority_Renew(t *testing.T) {
expiry := now.Add(time.Minute * 7) expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(), assert.Equals(t, leaf.Subject.String(),
@ -956,7 +980,7 @@ func TestAuthority_Rekey(t *testing.T) {
now := time.Now().UTC() now := time.Now().UTC()
nb1 := now.Add(-time.Minute * 7) nb1 := now.Add(-time.Minute * 7)
na1 := now na1 := now.Add(time.Hour)
so := &provisioner.SignOptions{ so := &provisioner.SignOptions{
NotBefore: provisioner.NewTimeDuration(nb1), NotBefore: provisioner.NewTimeDuration(nb1),
NotAfter: provisioner.NewTimeDuration(na1), NotAfter: provisioner.NewTimeDuration(na1),
@ -998,7 +1022,7 @@ func TestAuthority_Rekey(t *testing.T) {
"fail/unauthorized": func() (*renewTest, error) { "fail/unauthorized": func() (*renewTest, error) {
return &renewTest{ return &renewTest{
cert: certNoRenew, cert: certNoRenew,
err: errors.New("authority.Rekey: authority.authorizeRenew: jwk.AuthorizeRenew; renew is disabled for jwk provisioner 'dev'"), err: errors.New("authority.Rekey: authority.authorizeRenew: renew is disabled for provisioner 'dev'"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
}, nil }, nil
}, },
@ -1063,7 +1087,7 @@ func TestAuthority_Rekey(t *testing.T) {
expiry := now.Add(time.Minute * 7) expiry := now.Add(time.Minute * 7)
assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute))) assert.True(t, leaf.NotAfter.After(expiry.Add(-2*time.Minute)))
assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Minute))) assert.True(t, leaf.NotAfter.Before(expiry.Add(time.Hour)))
tmplt := a.config.AuthorityConfig.Template tmplt := a.config.AuthorityConfig.Template
assert.Equals(t, leaf.Subject.String(), assert.Equals(t, leaf.Subject.String(),

View file

@ -408,6 +408,7 @@ func TestBootstrapClientServerRotation(t *testing.T) {
server.ServeTLS(listener, "", "") server.ServeTLS(listener, "", "")
}() }()
defer server.Close() defer server.Close()
time.Sleep(1 * time.Second)
// Create bootstrap client // Create bootstrap client
token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7") token = generateBootstrapToken(caURL, "client", "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7")
@ -419,7 +420,6 @@ func TestBootstrapClientServerRotation(t *testing.T) {
// doTest does a request that requires mTLS // doTest does a request that requires mTLS
doTest := func(client *http.Client) error { doTest := func(client *http.Client) error {
time.Sleep(1 * time.Second)
// test with ca // test with ca
resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody) resp, err := client.Post(caURL+"/renew", "application/json", http.NoBody)
if err != nil { if err != nil {

View file

@ -451,9 +451,6 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) {
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
tlsConfig.ClientCAs = certPool tlsConfig.ClientCAs = certPool
// Use server's most preferred ciphersuite
tlsConfig.PreferServerCipherSuites = true
return tlsConfig, nil return tlsConfig, nil
} }

View file

@ -563,6 +563,11 @@ func (c *Client) retryOnError(r *http.Response) bool {
return false return false
} }
// GetCaURL returns the configured CA url.
func (c *Client) GetCaURL() string {
return c.endpoint.String()
}
// GetRootCAs returns the RootCAs certificate pool from the configured // GetRootCAs returns the RootCAs certificate pool from the configured
// transport. // transport.
func (c *Client) GetRootCAs() *x509.CertPool { func (c *Client) GetRootCAs() *x509.CertPool {
@ -723,6 +728,36 @@ retry:
return &sign, nil return &sign, nil
} }
// RenewWithToken performs the renew request to the CA with the given
// authorization token and returns the api.SignResponse struct. This method is
// generally used to renew an expired certificate.
func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
req, err := http.NewRequest("POST", u.String(), http.NoBody)
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error creating request")
}
req.Header.Add("Authorization", "Bearer "+token)
retry:
resp, err := c.client.Do(req)
if err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; client POST %s failed", u)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
retried = true
goto retry
}
return nil, readError(resp.Body)
}
var sign api.SignResponse
if err := readJSON(resp.Body, &sign); err != nil {
return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.RenewWithToken; error reading %s", u)
}
return &sign, nil
}
// Rekey performs the rekey request to the CA and returns the api.SignResponse // Rekey performs the rekey request to the CA and returns the api.SignResponse
// struct. // struct.
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {

View file

@ -17,13 +17,16 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh"
) )
const ( const (
@ -354,7 +357,7 @@ func TestClient_Sign(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.SignRequest) body := new(api.SignRequest)
if err := api.ReadJSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e) api.WriteError(w, e)
@ -426,7 +429,7 @@ func TestClient_Revoke(t *testing.T) {
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
body := new(api.RevokeRequest) body := new(api.RevokeRequest)
if err := api.ReadJSON(req.Body, body); err != nil { if err := read.JSON(req.Body, body); err != nil {
e, ok := tt.response.(error) e, ok := tt.response.(error)
assert.Fatal(t, ok, "response expected to be error type") assert.Fatal(t, ok, "response expected to be error type")
api.WriteError(w, e) api.WriteError(w, e)
@ -529,6 +532,74 @@ func TestClient_Renew(t *testing.T) {
} }
} }
func TestClient_RenewWithToken(t *testing.T) {
ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)},
CertChainPEM: []api.Certificate{
{Certificate: parseCertificate(certPEM)},
{Certificate: parseCertificate(rootPEM)},
},
}
tests := []struct {
name string
response interface{}
responseCode int
wantErr bool
err error
}{
{"ok", ok, 200, false, nil},
{"unauthorized", errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)},
{"empty request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
{"nil request", errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestPrefix)},
}
srv := httptest.NewServer(nil)
defer srv.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("Authorization") != "Bearer token" {
api.JSONStatus(w, errs.InternalServer("force"), 500)
} else {
api.JSONStatus(w, tt.response, tt.responseCode)
}
})
got, err := c.RenewWithToken("token")
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.RenewWithToken() error = %v, wantErr %v", err, tt.wantErr)
return
}
switch {
case err != nil:
if got != nil {
t.Errorf("Client.RenewWithToken() = %v, want nil", got)
}
sc, ok := err.(errs.StatusCoder)
assert.Fatal(t, ok, "error does not implement StatusCoder interface")
assert.Equals(t, sc.StatusCode(), tt.responseCode)
assert.HasPrefix(t, err.Error(), tt.err.Error())
default:
if !reflect.DeepEqual(got, tt.response) {
t.Errorf("Client.RenewWithToken() = %v, want %v", got, tt.response)
}
}
})
}
}
func TestClient_Rekey(t *testing.T) { func TestClient_Rekey(t *testing.T) {
ok := &api.SignResponse{ ok := &api.SignResponse{
ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)},
@ -1060,3 +1131,28 @@ func TestClient_SSHBastion(t *testing.T) {
}) })
} }
} }
func TestClient_GetCaURL(t *testing.T) {
tests := []struct {
name string
caURL string
want string
}{
{"ok", "https://ca.com", "https://ca.com"},
{"ok no schema", "ca.com", "https://ca.com"},
{"ok with port", "https://ca.com:9000", "https://ca.com:9000"},
{"ok with version", "https://ca.com/1.0", "https://ca.com/1.0"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(tt.caURL, WithTransport(http.DefaultTransport))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
if got := c.GetCaURL(); got != tt.want {
t.Errorf("Client.GetCaURL() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -8,6 +8,7 @@ import (
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
"sort"
"testing" "testing"
) )
@ -196,7 +197,7 @@ func TestLoadClient(t *testing.T) {
switch { switch {
case gotTransport.TLSClientConfig.GetClientCertificate == nil: case gotTransport.TLSClientConfig.GetClientCertificate == nil:
t.Error("LoadClient() transport does not define GetClientCertificate") t.Error("LoadClient() transport does not define GetClientCertificate")
case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !reflect.DeepEqual(gotTransport.TLSClientConfig.RootCAs.Subjects(), wantTransport.TLSClientConfig.RootCAs.Subjects()): case !reflect.DeepEqual(got.CaURL, tt.want.CaURL) || !equalPools(gotTransport.TLSClientConfig.RootCAs, wantTransport.TLSClientConfig.RootCAs):
t.Errorf("LoadClient() = %#v, want %#v", got, tt.want) t.Errorf("LoadClient() = %#v, want %#v", got, tt.want)
default: default:
crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil) crt, err := gotTransport.TLSClientConfig.GetClientCertificate(nil)
@ -238,3 +239,23 @@ func Test_defaultsConfig_Validate(t *testing.T) {
}) })
} }
} }
// nolint:staticcheck,gocritic
func equalPools(a, b *x509.CertPool) bool {
if reflect.DeepEqual(a, b) {
return true
}
subjects := a.Subjects()
sA := make([]string, len(subjects))
for i := range subjects {
sA[i] = string(subjects[i])
}
subjects = b.Subjects()
sB := make([]string, len(subjects))
for i := range subjects {
sB[i] = string(subjects[i])
}
sort.Strings(sA)
sort.Strings(sB)
return reflect.DeepEqual(sA, sB)
}

View file

@ -346,6 +346,8 @@ func TestIdentity_GetCertPool(t *testing.T) {
return return
} }
if got != nil { if got != nil {
// nolint:staticcheck // we don't have a different way to check
// the certificates in the pool.
subjects := got.Subjects() subjects := got.Subjects()
if !reflect.DeepEqual(subjects, tt.wantSubjects) { if !reflect.DeepEqual(subjects, tt.wantSubjects) {
t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects) t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)

2
ca/testdata/ca.json vendored
View file

@ -6,7 +6,7 @@
"password": "password", "password": "password",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.3, "maxVersion": 1.3,

View file

@ -6,7 +6,7 @@
"password": "asdf", "password": "asdf",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.2, "maxVersion": 1.2,

View file

@ -5,7 +5,7 @@
"password": "password", "password": "password",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.2, "maxVersion": 1.2,

View file

@ -5,7 +5,7 @@
"password": "password", "password": "password",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.2, "maxVersion": 1.2,

View file

@ -5,7 +5,7 @@
"password": "asdf", "password": "asdf",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.2, "maxVersion": 1.2,

View file

@ -5,7 +5,7 @@
"password": "asdf", "password": "asdf",
"address": "127.0.0.1:0", "address": "127.0.0.1:0",
"dnsNames": ["127.0.0.1"], "dnsNames": ["127.0.0.1"],
"logger": {"format": "text"}, "_logger": {"format": "text"},
"tls": { "tls": {
"minVersion": 1.2, "minVersion": 1.2,
"maxVersion": 1.2, "maxVersion": 1.2,

View file

@ -95,7 +95,6 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
// Note that with GetClientCertificate tlsConfig.Certificates is not used. // Note that with GetClientCertificate tlsConfig.Certificates is not used.
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
// Apply options and initialize mutable tls.Config // Apply options and initialize mutable tls.Config
tlsCtx := newTLSOptionCtx(c, tlsConfig, sign) tlsCtx := newTLSOptionCtx(c, tlsConfig, sign)
@ -137,7 +136,6 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
// Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate() // Without tlsConfig.Certificates there's not need to use tlsConfig.BuildNameToCertificate()
tlsConfig.GetCertificate = renewer.GetCertificate tlsConfig.GetCertificate = renewer.GetCertificate
tlsConfig.GetClientCertificate = renewer.GetClientCertificate tlsConfig.GetClientCertificate = renewer.GetClientCertificate
tlsConfig.PreferServerCipherSuites = true
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
// Apply options and initialize mutable tls.Config // Apply options and initialize mutable tls.Config

View file

@ -542,6 +542,7 @@ func TestAddFederationToCAs(t *testing.T) {
} }
} }
// nolint:staticcheck,gocritic
func equalPools(a, b *x509.CertPool) bool { func equalPools(a, b *x509.CertPool) bool {
if reflect.DeepEqual(a, b) { if reflect.DeepEqual(a, b) {
return true return true

6
go.mod
View file

@ -30,12 +30,12 @@ require (
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/slackhq/nebula v1.5.2 github.com/slackhq/nebula v1.5.2
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262
github.com/smallstep/nosql v0.3.10 github.com/smallstep/nosql v0.4.0
github.com/urfave/cli v1.22.4 github.com/urfave/cli v1.22.4
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.7.0 go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.15.0 go.step.sm/crypto v0.15.3
go.step.sm/linkedca v0.10.0 go.step.sm/linkedca v0.11.0
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3
golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d golang.org/x/net v0.0.0-20220114011407-0dd24b26b47d
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect golang.org/x/sys v0.0.0-20220114195835-da31bd327af9 // indirect

8
go.sum
View file

@ -607,8 +607,8 @@ github.com/slackhq/nebula v1.5.2/go.mod h1:xaCM6wqbFk/NRmmUe1bv88fWBm3a1UioXJVIp
github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
github.com/smallstep/nosql v0.3.10 h1:Xs7nueSl250GYb5XdfbzR8w+xPbvF6/oSw6pryY7gJI= github.com/smallstep/nosql v0.4.0 h1:Go3WYwttUuvwqMtFiiU4g7kBIlY+hR0bIZAqVdakQ3M=
github.com/smallstep/nosql v0.3.10/go.mod h1:yKZT5h7cdIVm6wEKM9+jN5dgK80Hljpuy8HNsnI7Gzo= github.com/smallstep/nosql v0.4.0/go.mod h1:yKZT5h7cdIVm6wEKM9+jN5dgK80Hljpuy8HNsnI7Gzo=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
@ -684,8 +684,8 @@ go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqe
go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds= go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds=
go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E=
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.15.0 h1:VioBln+x3+RoejgeBhvxkLGVYdWRy6PFiAaUUN29/E0= go.step.sm/crypto v0.15.3 h1:f3GMl+aCydt294BZRjTYwpaXRqwwndvoTY2NLN4wu10=
go.step.sm/crypto v0.15.0/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g= go.step.sm/crypto v0.15.3/go.mod h1:3G0yQr5lQqfEG0CMYz8apC/qMtjLRQlzflL2AxkcN+g=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=

View file

@ -10,14 +10,14 @@ import (
"strings" "strings"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/api" microscep "github.com/micromdm/scep/v2/scep"
"github.com/smallstep/certificates/authority/provisioner" "github.com/pkg/errors"
"github.com/smallstep/certificates/scep"
"go.mozilla.org/pkcs7" "go.mozilla.org/pkcs7"
"github.com/pkg/errors" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/log"
microscep "github.com/micromdm/scep/v2/scep" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/scep"
) )
const ( const (
@ -66,7 +66,9 @@ func New(scepAuth scep.Interface) api.RouterHandler {
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) { func (h *Handler) Route(r api.Router) {
getLink := h.Auth.GetLinkExplicit getLink := h.Auth.GetLinkExplicit
r.MethodFunc(http.MethodGet, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Get))
r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get)) r.MethodFunc(http.MethodGet, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Get))
r.MethodFunc(http.MethodPost, getLink("{provisionerName}/*", false, nil), h.lookupProvisioner(h.Post))
r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post)) r.MethodFunc(http.MethodPost, getLink("{provisionerName}", false, nil), h.lookupProvisioner(h.Post))
} }
@ -335,7 +337,7 @@ func formatCapabilities(caps []string) []byte {
func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) { func writeSCEPResponse(w http.ResponseWriter, response SCEPResponse) {
if response.Error != nil { if response.Error != nil {
api.LogError(w, response.Error) log.Error(w, response.Error)
} }
if response.Certificate != nil { if response.Certificate != nil {