Make provisioner more configurable.

The intention of this change is to make it usable from cert-manager.
This commit is contained in:
Mariano Cano 2019-06-17 19:01:04 -07:00
parent 4075407d63
commit 01b6aebbf7
5 changed files with 249 additions and 88 deletions

4
Gopkg.lock generated
View file

@ -338,7 +338,7 @@
[[projects]]
branch = "master"
digest = "1:dad8c405e442687cf36b3d236290554168c78c0b71c0f2cef35635933a225fe3"
digest = "1:5be7f46f64720a346290a9aeb734521523ebe4bbc0180c1d9f4355fac578e3c3"
name = "github.com/smallstep/cli"
packages = [
"command",
@ -359,7 +359,7 @@
"utils",
]
pruneopts = "UT"
revision = "8429a2f6f5d6f097b843322a9a8e80d6fd087258"
revision = "ea92c904da59b3892f53435d01902156ec5e171d"
[[projects]]
branch = "master"

View file

@ -16,12 +16,15 @@ import (
"io/ioutil"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/cli/config"
"github.com/smallstep/cli/crypto/x509util"
"gopkg.in/square/go-jose.v2/jwt"
)
@ -33,6 +36,7 @@ type clientOptions struct {
transport http.RoundTripper
rootSHA256 string
rootFilename string
rootBundle []byte
}
func (o *clientOptions) apply(opts []ClientOption) (err error) {
@ -47,7 +51,7 @@ func (o *clientOptions) apply(opts []ClientOption) (err error) {
// checkTransport checks if other ways to set up a transport have been provided.
// If they have it returns an error.
func (o *clientOptions) checkTransport() error {
if o.transport != nil || o.rootFilename != "" || o.rootSHA256 != "" {
if o.transport != nil || o.rootFilename != "" || o.rootSHA256 != "" || o.rootBundle != nil {
return errors.New("multiple transport methods have been configured")
}
return nil
@ -68,14 +72,27 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
return nil, err
}
}
if o.rootBundle != nil {
if tr, err = getTransportFromCABundle(o.rootBundle); err != nil {
return nil, err
}
}
// As the last option attempt to load the default root ca
if tr == nil {
rootFile := getRootCAPath()
if _, err := os.Stat(rootFile); err == nil {
if tr, err = getTransportFromFile(rootFile); err != nil {
return nil, err
}
return tr, nil
}
return nil, errors.New("a transport, a root cert, or a root sha256 must be used")
}
return tr, nil
}
// WithTransport adds a custom transport to the Client. If the transport is
// given is given it will have preference over WithRootFile and WithRootSHA256.
// WithTransport adds a custom transport to the Client. It will fail if a
// previous option to create the transport has been configured.
func WithTransport(tr http.RoundTripper) ClientOption {
return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil {
@ -86,9 +103,8 @@ func WithTransport(tr http.RoundTripper) ClientOption {
}
}
// WithRootFile will create the transport using the given root certificate. If
// the root file is given it will have preference over WithRootSHA256, but less
// preference than WithTransport.
// WithRootFile will create the transport using the given root certificate. It
// will fail if a previous option to create the transport has been configured.
func WithRootFile(filename string) ClientOption {
return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil {
@ -99,8 +115,9 @@ func WithRootFile(filename string) ClientOption {
}
}
// WithRootSHA256 will create the transport using an insecure client to retrieve the
// root certificate. It has less preference than WithTransport and WithRootFile.
// WithRootSHA256 will create the transport using an insecure client to retrieve
// the root certificate using its fingerprint. It will fail if a previous option
// to create the transport has been configured.
func WithRootSHA256(sum string) ClientOption {
return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil {
@ -111,6 +128,18 @@ func WithRootSHA256(sum string) ClientOption {
}
}
// WithCABundle will create the transport using the given root certificates. It
// will fail if a previous option to create the transport has been configured.
func WithCABundle(bundle []byte) ClientOption {
return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil {
return err
}
o.rootBundle = bundle
return nil
}
}
func getTransportFromFile(filename string) (http.RoundTripper, error) {
data, err := ioutil.ReadFile(filename)
if err != nil {
@ -146,6 +175,18 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
})
}
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(bundle) {
return nil, errors.New("error parsing ca bundle: no certificates found")
}
return getDefaultTransport(&tls.Config{
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}
// parseEndpoint parses and validates the given endpoint. It supports general
// URLs like https://ca.smallstep.com[:port][/path], and incomplete URLs like
// ca.smallstep.com[:port][/path].
@ -464,6 +505,25 @@ func (c *Client) Federation() (*api.FederationResponse, error) {
return &federation, nil
}
// RootFingerprint is a helper method that returns the current root fingerprint.
// It does an health connection and gets the fingerprint from the TLS verified
// chains.
func (c *Client) RootFingerprint() (string, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
resp, err := c.client.Get(u.String())
if err != nil {
return "", errors.Wrapf(err, "client GET %s failed", u)
}
if resp.TLS == nil || len(resp.TLS.VerifiedChains) == 0 {
return "", errors.New("missing verified chains")
}
lastChain := resp.TLS.VerifiedChains[len(resp.TLS.VerifiedChains)-1]
if len(lastChain) == 0 {
return "", errors.New("missing verified chains")
}
return x509util.Fingerprint(lastChain[len(lastChain)-1]), nil
}
// CreateSignRequest is a helper function that given an x509 OTT returns a
// simple but secure sign request as well as the private key used.
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {
@ -522,6 +582,12 @@ func getInsecureClient() *http.Client {
}
}
// getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable.
func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt")
}
func readJSON(r io.ReadCloser, v interface{}) error {
defer r.Close()
return json.NewDecoder(r).Decode(v)

View file

@ -13,8 +13,10 @@ import (
"testing"
"time"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/x509util"
)
const (
@ -746,3 +748,66 @@ func Test_parseEndpoint(t *testing.T) {
})
}
}
func TestClient_RootFingerprint(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"}
nok := api.InternalServerError(fmt.Errorf("Internal Server Error"))
httpsServer := httptest.NewTLSServer(nil)
defer httpsServer.Close()
httpsServerFingerprint := x509util.Fingerprint(httpsServer.Certificate())
httpServer := httptest.NewServer(nil)
defer httpServer.Close()
tests := []struct {
name string
server *httptest.Server
response interface{}
responseCode int
want string
wantErr bool
}{
{"ok", httpsServer, ok, 200, httpsServerFingerprint, false},
{"ok with error", httpsServer, nok, 500, httpsServerFingerprint, false},
{"fail", httpServer, ok, 200, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tr := tt.server.Client().Transport
c, err := NewClient(tt.server.URL, WithTransport(tr))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.RootFingerprint()
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want)
}
})
}
}
func TestClient_RootFingerprintWithServer(t *testing.T) {
srv := startCABootstrapServer()
defer srv.Close()
client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
assert.FatalError(t, err)
fp, err := client.RootFingerprint()
assert.FatalError(t, err)
assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
}

View file

@ -2,30 +2,27 @@ package ca
import (
"encoding/json"
"fmt"
"path/filepath"
"net/url"
"time"
"github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/config"
"github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose"
"github.com/smallstep/cli/token"
"github.com/smallstep/cli/token/provision"
)
const (
tokenLifetime = 5 * time.Minute
)
const tokenLifetime = 5 * time.Minute
// Provisioner is an authorized entity that can sign tokens necessary for
// signature requests.
type Provisioner struct {
*Client
name string
kid string
caURL string
caRoot string
audience string
fingerprint string
jwk *jose.JSONWebKey
tokenLifetime time.Duration
}
@ -34,26 +31,36 @@ type Provisioner struct {
// provisioner. The key identified by `kid` will be used if specified. If `kid`
// is the empty string we'll use the first key for the named provisioner that
// decrypts using `password`.
func NewProvisioner(name, kid, caURL, caRoot string, password []byte) (*Provisioner, error) {
var jwk *jose.JSONWebKey
var err error
switch {
case name == "":
return nil, errors.New("provisioner name cannot be empty")
case kid == "":
jwk, err = loadProvisionerJWKByName(name, caURL, caRoot, password)
default:
jwk, err = loadProvisionerJWKByKid(kid, caURL, caRoot, password)
}
func NewProvisioner(name, kid, caURL string, password []byte, opts ...ClientOption) (*Provisioner, error) {
client, err := NewClient(caURL, opts...)
if err != nil {
return nil, err
}
// Get the fingerprint of the current connection
fp, err := client.RootFingerprint()
if err != nil {
return nil, err
}
var jwk *jose.JSONWebKey
switch {
case name == "":
return nil, errors.New("provisioner name cannot be empty")
case kid == "":
jwk, err = loadProvisionerJWKByName(client, name, password)
default:
jwk, err = loadProvisionerJWKByKid(client, kid, password)
}
if err != nil {
return nil, err
}
return &Provisioner{
Client: client,
name: name,
kid: jwk.KeyID,
caURL: caURL,
caRoot: caRoot,
audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(),
fingerprint: fp,
jwk: jwk,
tokenLifetime: tokenLifetime,
}, nil
@ -69,8 +76,17 @@ func (p *Provisioner) Kid() string {
return p.kid
}
// SetFingerprint overwrites the default fingerprint used.
func (p *Provisioner) SetFingerprint(sum string) {
p.fingerprint = sum
}
// Token generates a bootstrap token for a subject.
func (p *Provisioner) Token(subject string) (string, error) {
func (p *Provisioner) Token(subject string, sans ...string) (string, error) {
if len(sans) == 0 {
sans = []string{subject}
}
// A random jwt id will be used to identify duplicated tokens
jwtID, err := randutil.Hex(64) // 256 bits
if err != nil {
@ -79,16 +95,17 @@ func (p *Provisioner) Token(subject string) (string, error) {
notBefore := time.Now()
notAfter := notBefore.Add(tokenLifetime)
signURL := fmt.Sprintf("%v/1.0/sign", p.caURL)
tokOptions := []token.Options{
token.WithJWTID(jwtID),
token.WithKid(p.kid),
token.WithIssuer(p.name),
token.WithAudience(signURL),
token.WithAudience(p.audience),
token.WithValidity(notBefore, notAfter),
token.WithRootCA(p.caRoot),
token.WithSANS([]string{subject}),
token.WithSANS(sans),
}
if p.fingerprint != "" {
tokOptions = append(tokOptions, token.WithSHA(p.fingerprint))
}
tok, err := provision.New(subject, tokOptions...)
@ -117,8 +134,8 @@ func decryptProvisionerJWK(encryptedKey string, password []byte) (*jose.JSONWebK
// loadProvisionerJWKByKid retrieves a provisioner key from the CA by key ID and
// decrypts it using the specified password.
func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose.JSONWebKey, error) {
encrypted, err := getProvisionerKey(caURL, caRoot, kid)
func loadProvisionerJWKByKid(client *Client, kid string, password []byte) (*jose.JSONWebKey, error) {
encrypted, err := getProvisionerKey(client, kid)
if err != nil {
return nil, err
}
@ -129,8 +146,8 @@ func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose.
// loadProvisionerJWKByName retrieves the list of provisioners and encrypted key then
// returns the key of the first provisioner with a matching name that can be successfully
// decrypted with the specified password.
func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key *jose.JSONWebKey, err error) {
provisioners, err := getProvisioners(caURL, caRoot)
func loadProvisionerJWKByName(client *Client, name string, password []byte) (key *jose.JSONWebKey, err error) {
provisioners, err := getProvisioners(client)
if err != nil {
err = errors.Wrap(err, "error getting the provisioners")
return
@ -149,22 +166,9 @@ func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key
return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name)
}
// getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable.
func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt")
}
// getProvisioners returns the map of provisioners on the given CA.
func getProvisioners(caURL, rootFile string) (provisioner.List, error) {
if len(rootFile) == 0 {
rootFile = getRootCAPath()
}
client, err := NewClient(caURL, WithRootFile(rootFile))
if err != nil {
return nil, err
}
cursor := ""
// getProvisioners returns the list of provisioners using the configured client.
func getProvisioners(client *Client) (provisioner.List, error) {
var cursor string
var provisioners provisioner.List
for {
resp, err := client.Provisioners(WithProvisionerCursor(cursor), WithProvisionerLimit(100))
@ -180,14 +184,7 @@ func getProvisioners(caURL, rootFile string) (provisioner.List, error) {
}
// getProvisionerKey returns the encrypted provisioner key for the given kid.
func getProvisionerKey(caURL, rootFile, kid string) (string, error) {
if len(rootFile) == 0 {
rootFile = getRootCAPath()
}
client, err := NewClient(caURL, WithRootFile(rootFile))
if err != nil {
return "", err
}
func getProvisionerKey(client *Client, kid string) (string, error) {
resp, err := client.ProvisionerKey(kid)
if err != nil {
return "", err

View file

@ -1,23 +1,38 @@
package ca
import (
"net/url"
"reflect"
"testing"
"time"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose"
)
func getTestProvisioner(t *testing.T, url string) *Provisioner {
func getTestProvisioner(t *testing.T, caURL string) *Provisioner {
jwk, err := jose.ParseKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password")))
if err != nil {
t.Fatal(err)
}
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
client, err := NewClient(caURL)
if err != nil {
t.Fatal(err)
}
return &Provisioner{
Client: client,
name: "mariano",
kid: "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
caURL: url,
caRoot: "testdata/secrets/root_ca.crt",
audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(),
fingerprint: x509util.Fingerprint(cert),
jwk: jwk,
tokenLifetime: 5 * time.Minute,
}
@ -32,8 +47,8 @@ func TestNewProvisioner(t *testing.T) {
name string
kid string
caURL string
caRoot string
password []byte
caRoot string
}
tests := []struct {
name string
@ -41,21 +56,27 @@ func TestNewProvisioner(t *testing.T) {
want *Provisioner
wantErr bool
}{
{"ok", args{want.name, want.kid, want.caURL, want.caRoot, []byte("password")}, want, false},
{"ok-by-name", args{want.name, "", want.caURL, want.caRoot, []byte("password")}, want, false},
{"fail-bad-kid", args{want.name, "bad-kid", want.caURL, want.caRoot, []byte("password")}, nil, true},
{"fail-empty-name", args{"", want.kid, want.caURL, want.caRoot, []byte("password")}, nil, true},
{"fail-bad-name", args{"bad-name", "", want.caURL, want.caRoot, []byte("password")}, nil, true},
{"fail-by-password", args{want.name, want.kid, want.caURL, want.caRoot, []byte("bad-password")}, nil, true},
{"fail-by-password-no-kid", args{want.name, "", want.caURL, want.caRoot, []byte("bad-password")}, nil, true},
{"ok", args{want.name, want.kid, ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, want, false},
{"ok-by-name", args{want.name, "", ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, want, false},
{"fail-bad-kid", args{want.name, "bad-kid", ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, nil, true},
{"fail-empty-name", args{"", want.kid, ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, nil, true},
{"fail-bad-name", args{"bad-name", "", ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, nil, true},
{"fail-by-password", args{want.name, want.kid, ca.URL, []byte("bad-password"), "testdata/secrets/root_ca.crt"}, nil, true},
{"fail-by-password-no-kid", args{want.name, "", ca.URL, []byte("bad-password"), "testdata/secrets/root_ca.crt"}, nil, true},
{"fail-bad-certificate", args{want.name, want.kid, ca.URL, []byte("password"), "testdata/secrets/federatec_ca.crt"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.caRoot, tt.args.password)
got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.password, WithRootFile(tt.args.caRoot))
if (err != nil) != tt.wantErr {
t.Errorf("NewProvisioner() error = %v, wantErr %v", err, tt.wantErr)
return
}
// Client won't match.
// Make sure it does.
if got != nil {
got.Client = want.Client
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewProvisioner() = %v, want %v", got, tt.want)
}
@ -80,13 +101,13 @@ func TestProvisioner_Token(t *testing.T) {
type fields struct {
name string
kid string
caURL string
caRoot string
fingerprint string
jwk *jose.JSONWebKey
tokenLifetime time.Duration
}
type args struct {
subject string
sans []string
}
tests := []struct {
name string
@ -94,21 +115,23 @@ func TestProvisioner_Token(t *testing.T) {
args args
wantErr bool
}{
{"ok", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{"subject"}, false},
{"fail-no-subject", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{""}, true},
{"fail-no-key", fields{p.name, p.kid, p.caURL, p.caRoot, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject"}, true},
{"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", nil}, false},
{"ok-with-san", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com"}}, false},
{"ok-with-sans", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com", "127.0.0.1"}}, false},
{"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"", []string{"foo.smallstep.com"}}, true},
{"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject", nil}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Provisioner{
name: tt.fields.name,
kid: tt.fields.kid,
caURL: tt.fields.caURL,
caRoot: tt.fields.caRoot,
audience: "https://127.0.0.1:9000/1.0/sign",
fingerprint: tt.fields.fingerprint,
jwk: tt.fields.jwk,
tokenLifetime: tt.fields.tokenLifetime,
}
got, err := p.Token(tt.args.subject)
got, err := p.Token(tt.args.subject, tt.args.sans...)
if (err != nil) != tt.wantErr {
t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr)
return
@ -126,7 +149,7 @@ func TestProvisioner_Token(t *testing.T) {
return
}
if err := claims.ValidateWithLeeway(jose.Expected{
Audience: []string{tt.fields.caURL + "/1.0/sign"},
Audience: []string{"https://127.0.0.1:9000/1.0/sign"},
Issuer: tt.fields.name,
Subject: tt.args.subject,
Time: time.Now().UTC(),
@ -146,8 +169,18 @@ func TestProvisioner_Token(t *testing.T) {
if v, ok := allClaims["sha"].(string); !ok || v != sha {
t.Errorf("Claim sha = %s, want %s", v, sha)
}
if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) {
t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject})
if len(tt.args.sans) == 0 {
if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) {
t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject})
}
} else {
want := []interface{}{}
for _, s := range tt.args.sans {
want = append(want, s)
}
if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, want) {
t.Errorf("Claim sans = %s, want %s", v, want)
}
}
if v, ok := allClaims["jti"].(string); !ok || v == "" {
t.Errorf("Claim jti = %s, want not blank", v)