Merge pull request #80 from smallstep/cert-manager

Improve ca.Provisioner
This commit is contained in:
Mariano Cano 2019-06-24 10:59:00 -07:00 committed by GitHub
commit f12e2dedd5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 262 additions and 92 deletions

4
Gopkg.lock generated
View file

@ -338,7 +338,7 @@
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:dad8c405e442687cf36b3d236290554168c78c0b71c0f2cef35635933a225fe3" digest = "1:8eb842c27bca9dae16d77baeba7cf612135033da381faf833bb8c11c29a751c7"
name = "github.com/smallstep/cli" name = "github.com/smallstep/cli"
packages = [ packages = [
"command", "command",
@ -359,7 +359,7 @@
"utils", "utils",
] ]
pruneopts = "UT" pruneopts = "UT"
revision = "8429a2f6f5d6f097b843322a9a8e80d6fd087258" revision = "98635d188cade54451e3997b530716297ce7fc00"
[[projects]] [[projects]]
branch = "master" branch = "master"

View file

@ -16,12 +16,15 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/cli/config"
"github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/crypto/x509util"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
) )
@ -33,6 +36,7 @@ type clientOptions struct {
transport http.RoundTripper transport http.RoundTripper
rootSHA256 string rootSHA256 string
rootFilename string rootFilename string
rootBundle []byte
} }
func (o *clientOptions) apply(opts []ClientOption) (err error) { func (o *clientOptions) apply(opts []ClientOption) (err error) {
@ -47,7 +51,7 @@ func (o *clientOptions) apply(opts []ClientOption) (err error) {
// checkTransport checks if other ways to set up a transport have been provided. // checkTransport checks if other ways to set up a transport have been provided.
// If they have it returns an error. // If they have it returns an error.
func (o *clientOptions) checkTransport() error { func (o *clientOptions) checkTransport() error {
if o.transport != nil || o.rootFilename != "" || o.rootSHA256 != "" { if o.transport != nil || o.rootFilename != "" || o.rootSHA256 != "" || o.rootBundle != nil {
return errors.New("multiple transport methods have been configured") return errors.New("multiple transport methods have been configured")
} }
return nil return nil
@ -68,14 +72,27 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
return nil, err return nil, err
} }
} }
if o.rootBundle != nil {
if tr, err = getTransportFromCABundle(o.rootBundle); err != nil {
return nil, err
}
}
// As the last option attempt to load the default root ca
if tr == nil { if tr == nil {
rootFile := getRootCAPath()
if _, err := os.Stat(rootFile); err == nil {
if tr, err = getTransportFromFile(rootFile); err != nil {
return nil, err
}
return tr, nil
}
return nil, errors.New("a transport, a root cert, or a root sha256 must be used") return nil, errors.New("a transport, a root cert, or a root sha256 must be used")
} }
return tr, nil return tr, nil
} }
// WithTransport adds a custom transport to the Client. If the transport is // WithTransport adds a custom transport to the Client. It will fail if a
// given is given it will have preference over WithRootFile and WithRootSHA256. // previous option to create the transport has been configured.
func WithTransport(tr http.RoundTripper) ClientOption { func WithTransport(tr http.RoundTripper) ClientOption {
return func(o *clientOptions) error { return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil { if err := o.checkTransport(); err != nil {
@ -86,9 +103,8 @@ func WithTransport(tr http.RoundTripper) ClientOption {
} }
} }
// WithRootFile will create the transport using the given root certificate. If // WithRootFile will create the transport using the given root certificate. It
// the root file is given it will have preference over WithRootSHA256, but less // will fail if a previous option to create the transport has been configured.
// preference than WithTransport.
func WithRootFile(filename string) ClientOption { func WithRootFile(filename string) ClientOption {
return func(o *clientOptions) error { return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil { if err := o.checkTransport(); err != nil {
@ -99,8 +115,9 @@ func WithRootFile(filename string) ClientOption {
} }
} }
// WithRootSHA256 will create the transport using an insecure client to retrieve the // WithRootSHA256 will create the transport using an insecure client to retrieve
// root certificate. It has less preference than WithTransport and WithRootFile. // the root certificate using its fingerprint. It will fail if a previous option
// to create the transport has been configured.
func WithRootSHA256(sum string) ClientOption { func WithRootSHA256(sum string) ClientOption {
return func(o *clientOptions) error { return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil { if err := o.checkTransport(); err != nil {
@ -111,6 +128,18 @@ func WithRootSHA256(sum string) ClientOption {
} }
} }
// WithCABundle will create the transport using the given root certificates. It
// will fail if a previous option to create the transport has been configured.
func WithCABundle(bundle []byte) ClientOption {
return func(o *clientOptions) error {
if err := o.checkTransport(); err != nil {
return err
}
o.rootBundle = bundle
return nil
}
}
func getTransportFromFile(filename string) (http.RoundTripper, error) { func getTransportFromFile(filename string) (http.RoundTripper, error) {
data, err := ioutil.ReadFile(filename) data, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
@ -146,6 +175,18 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) {
}) })
} }
func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) {
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(bundle) {
return nil, errors.New("error parsing ca bundle: no certificates found")
}
return getDefaultTransport(&tls.Config{
MinVersion: tls.VersionTLS12,
PreferServerCipherSuites: true,
RootCAs: pool,
})
}
// parseEndpoint parses and validates the given endpoint. It supports general // parseEndpoint parses and validates the given endpoint. It supports general
// URLs like https://ca.smallstep.com[:port][/path], and incomplete URLs like // URLs like https://ca.smallstep.com[:port][/path], and incomplete URLs like
// ca.smallstep.com[:port][/path]. // ca.smallstep.com[:port][/path].
@ -464,6 +505,25 @@ func (c *Client) Federation() (*api.FederationResponse, error) {
return &federation, nil return &federation, nil
} }
// RootFingerprint is a helper method that returns the current root fingerprint.
// It does an health connection and gets the fingerprint from the TLS verified
// chains.
func (c *Client) RootFingerprint() (string, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
resp, err := c.client.Get(u.String())
if err != nil {
return "", errors.Wrapf(err, "client GET %s failed", u)
}
if resp.TLS == nil || len(resp.TLS.VerifiedChains) == 0 {
return "", errors.New("missing verified chains")
}
lastChain := resp.TLS.VerifiedChains[len(resp.TLS.VerifiedChains)-1]
if len(lastChain) == 0 {
return "", errors.New("missing verified chains")
}
return x509util.Fingerprint(lastChain[len(lastChain)-1]), nil
}
// CreateSignRequest is a helper function that given an x509 OTT returns a // CreateSignRequest is a helper function that given an x509 OTT returns a
// simple but secure sign request as well as the private key used. // simple but secure sign request as well as the private key used.
func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) { func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) {
@ -522,6 +582,12 @@ func getInsecureClient() *http.Client {
} }
} }
// getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable.
func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt")
}
func readJSON(r io.ReadCloser, v interface{}) error { func readJSON(r io.ReadCloser, v interface{}) error {
defer r.Close() defer r.Close()
return json.NewDecoder(r).Decode(v) return json.NewDecoder(r).Decode(v)

View file

@ -13,8 +13,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/crypto/x509util"
) )
const ( const (
@ -746,3 +748,66 @@ func Test_parseEndpoint(t *testing.T) {
}) })
} }
} }
func TestClient_RootFingerprint(t *testing.T) {
ok := &api.HealthResponse{Status: "ok"}
nok := api.InternalServerError(fmt.Errorf("Internal Server Error"))
httpsServer := httptest.NewTLSServer(nil)
defer httpsServer.Close()
httpsServerFingerprint := x509util.Fingerprint(httpsServer.Certificate())
httpServer := httptest.NewServer(nil)
defer httpServer.Close()
tests := []struct {
name string
server *httptest.Server
response interface{}
responseCode int
want string
wantErr bool
}{
{"ok", httpsServer, ok, 200, httpsServerFingerprint, false},
{"ok with error", httpsServer, nok, 500, httpsServerFingerprint, false},
{"fail", httpServer, ok, 200, "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tr := tt.server.Client().Transport
c, err := NewClient(tt.server.URL, WithTransport(tr))
if err != nil {
t.Errorf("NewClient() error = %v", err)
return
}
tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(tt.responseCode)
api.JSON(w, tt.response)
})
got, err := c.RootFingerprint()
if (err != nil) != tt.wantErr {
fmt.Printf("%+v", err)
t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want)
}
})
}
}
func TestClient_RootFingerprintWithServer(t *testing.T) {
srv := startCABootstrapServer()
defer srv.Close()
client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt"))
assert.FatalError(t, err)
fp, err := client.RootFingerprint()
assert.FatalError(t, err)
assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp)
}

View file

@ -2,30 +2,27 @@ package ca
import ( import (
"encoding/json" "encoding/json"
"fmt" "net/url"
"path/filepath"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/cli/config"
"github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/crypto/randutil"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
"github.com/smallstep/cli/token" "github.com/smallstep/cli/token"
"github.com/smallstep/cli/token/provision" "github.com/smallstep/cli/token/provision"
) )
const ( const tokenLifetime = 5 * time.Minute
tokenLifetime = 5 * time.Minute
)
// Provisioner is an authorized entity that can sign tokens necessary for // Provisioner is an authorized entity that can sign tokens necessary for
// signature requests. // signature requests.
type Provisioner struct { type Provisioner struct {
*Client
name string name string
kid string kid string
caURL string audience string
caRoot string fingerprint string
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
tokenLifetime time.Duration tokenLifetime time.Duration
} }
@ -34,26 +31,36 @@ type Provisioner struct {
// provisioner. The key identified by `kid` will be used if specified. If `kid` // provisioner. The key identified by `kid` will be used if specified. If `kid`
// is the empty string we'll use the first key for the named provisioner that // is the empty string we'll use the first key for the named provisioner that
// decrypts using `password`. // decrypts using `password`.
func NewProvisioner(name, kid, caURL, caRoot string, password []byte) (*Provisioner, error) { func NewProvisioner(name, kid, caURL string, password []byte, opts ...ClientOption) (*Provisioner, error) {
var jwk *jose.JSONWebKey client, err := NewClient(caURL, opts...)
var err error
switch {
case name == "":
return nil, errors.New("provisioner name cannot be empty")
case kid == "":
jwk, err = loadProvisionerJWKByName(name, caURL, caRoot, password)
default:
jwk, err = loadProvisionerJWKByKid(kid, caURL, caRoot, password)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Get the fingerprint of the current connection
fp, err := client.RootFingerprint()
if err != nil {
return nil, err
}
var jwk *jose.JSONWebKey
switch {
case name == "":
return nil, errors.New("provisioner name cannot be empty")
case kid == "":
jwk, err = loadProvisionerJWKByName(client, name, password)
default:
jwk, err = loadProvisionerJWKByKid(client, kid, password)
}
if err != nil {
return nil, err
}
return &Provisioner{ return &Provisioner{
Client: client,
name: name, name: name,
kid: jwk.KeyID, kid: jwk.KeyID,
caURL: caURL, audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(),
caRoot: caRoot, fingerprint: fp,
jwk: jwk, jwk: jwk,
tokenLifetime: tokenLifetime, tokenLifetime: tokenLifetime,
}, nil }, nil
@ -69,8 +76,17 @@ func (p *Provisioner) Kid() string {
return p.kid return p.kid
} }
// SetFingerprint overwrites the default fingerprint used.
func (p *Provisioner) SetFingerprint(sum string) {
p.fingerprint = sum
}
// Token generates a bootstrap token for a subject. // Token generates a bootstrap token for a subject.
func (p *Provisioner) Token(subject string) (string, error) { func (p *Provisioner) Token(subject string, sans ...string) (string, error) {
if len(sans) == 0 {
sans = []string{subject}
}
// A random jwt id will be used to identify duplicated tokens // A random jwt id will be used to identify duplicated tokens
jwtID, err := randutil.Hex(64) // 256 bits jwtID, err := randutil.Hex(64) // 256 bits
if err != nil { if err != nil {
@ -79,16 +95,17 @@ func (p *Provisioner) Token(subject string) (string, error) {
notBefore := time.Now() notBefore := time.Now()
notAfter := notBefore.Add(tokenLifetime) notAfter := notBefore.Add(tokenLifetime)
signURL := fmt.Sprintf("%v/1.0/sign", p.caURL)
tokOptions := []token.Options{ tokOptions := []token.Options{
token.WithJWTID(jwtID), token.WithJWTID(jwtID),
token.WithKid(p.kid), token.WithKid(p.kid),
token.WithIssuer(p.name), token.WithIssuer(p.name),
token.WithAudience(signURL), token.WithAudience(p.audience),
token.WithValidity(notBefore, notAfter), token.WithValidity(notBefore, notAfter),
token.WithRootCA(p.caRoot), token.WithSANS(sans),
token.WithSANS([]string{subject}), }
if p.fingerprint != "" {
tokOptions = append(tokOptions, token.WithSHA(p.fingerprint))
} }
tok, err := provision.New(subject, tokOptions...) tok, err := provision.New(subject, tokOptions...)
@ -117,8 +134,8 @@ func decryptProvisionerJWK(encryptedKey string, password []byte) (*jose.JSONWebK
// loadProvisionerJWKByKid retrieves a provisioner key from the CA by key ID and // loadProvisionerJWKByKid retrieves a provisioner key from the CA by key ID and
// decrypts it using the specified password. // decrypts it using the specified password.
func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose.JSONWebKey, error) { func loadProvisionerJWKByKid(client *Client, kid string, password []byte) (*jose.JSONWebKey, error) {
encrypted, err := getProvisionerKey(caURL, caRoot, kid) encrypted, err := getProvisionerKey(client, kid)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -129,8 +146,8 @@ func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose.
// loadProvisionerJWKByName retrieves the list of provisioners and encrypted key then // loadProvisionerJWKByName retrieves the list of provisioners and encrypted key then
// returns the key of the first provisioner with a matching name that can be successfully // returns the key of the first provisioner with a matching name that can be successfully
// decrypted with the specified password. // decrypted with the specified password.
func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key *jose.JSONWebKey, err error) { func loadProvisionerJWKByName(client *Client, name string, password []byte) (key *jose.JSONWebKey, err error) {
provisioners, err := getProvisioners(caURL, caRoot) provisioners, err := getProvisioners(client)
if err != nil { if err != nil {
err = errors.Wrap(err, "error getting the provisioners") err = errors.Wrap(err, "error getting the provisioners")
return return
@ -149,22 +166,9 @@ func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key
return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name) return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name)
} }
// getRootCAPath returns the path where the root CA is stored based on the // getProvisioners returns the list of provisioners using the configured client.
// STEPPATH environment variable. func getProvisioners(client *Client) (provisioner.List, error) {
func getRootCAPath() string { var cursor string
return filepath.Join(config.StepPath(), "certs", "root_ca.crt")
}
// getProvisioners returns the map of provisioners on the given CA.
func getProvisioners(caURL, rootFile string) (provisioner.List, error) {
if len(rootFile) == 0 {
rootFile = getRootCAPath()
}
client, err := NewClient(caURL, WithRootFile(rootFile))
if err != nil {
return nil, err
}
cursor := ""
var provisioners provisioner.List var provisioners provisioner.List
for { for {
resp, err := client.Provisioners(WithProvisionerCursor(cursor), WithProvisionerLimit(100)) resp, err := client.Provisioners(WithProvisionerCursor(cursor), WithProvisionerLimit(100))
@ -180,14 +184,7 @@ func getProvisioners(caURL, rootFile string) (provisioner.List, error) {
} }
// getProvisionerKey returns the encrypted provisioner key for the given kid. // getProvisionerKey returns the encrypted provisioner key for the given kid.
func getProvisionerKey(caURL, rootFile, kid string) (string, error) { func getProvisionerKey(client *Client, kid string) (string, error) {
if len(rootFile) == 0 {
rootFile = getRootCAPath()
}
client, err := NewClient(caURL, WithRootFile(rootFile))
if err != nil {
return "", err
}
resp, err := client.ProvisionerKey(kid) resp, err := client.ProvisionerKey(kid)
if err != nil { if err != nil {
return "", err return "", err

View file

@ -1,23 +1,39 @@
package ca package ca
import ( import (
"io/ioutil"
"net/url"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/smallstep/cli/crypto/pemutil"
"github.com/smallstep/cli/crypto/x509util"
"github.com/smallstep/cli/jose" "github.com/smallstep/cli/jose"
) )
func getTestProvisioner(t *testing.T, url string) *Provisioner { func getTestProvisioner(t *testing.T, caURL string) *Provisioner {
jwk, err := jose.ParseKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) jwk, err := jose.ParseKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password")))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
client, err := NewClient(caURL, WithRootFile("testdata/secrets/root_ca.crt"))
if err != nil {
t.Fatal(err)
}
return &Provisioner{ return &Provisioner{
Client: client,
name: "mariano", name: "mariano",
kid: "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", kid: "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
caURL: url, audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(),
caRoot: "testdata/secrets/root_ca.crt", fingerprint: x509util.Fingerprint(cert),
jwk: jwk, jwk: jwk,
tokenLifetime: 5 * time.Minute, tokenLifetime: 5 * time.Minute,
} }
@ -28,12 +44,17 @@ func TestNewProvisioner(t *testing.T) {
defer ca.Close() defer ca.Close()
want := getTestProvisioner(t, ca.URL) want := getTestProvisioner(t, ca.URL)
caBundle, err := ioutil.ReadFile("testdata/secrets/root_ca.crt")
if err != nil {
t.Fatal(err)
}
type args struct { type args struct {
name string name string
kid string kid string
caURL string caURL string
caRoot string password []byte
password []byte clientOption ClientOption
} }
tests := []struct { tests := []struct {
name string name string
@ -41,21 +62,30 @@ func TestNewProvisioner(t *testing.T) {
want *Provisioner want *Provisioner
wantErr bool wantErr bool
}{ }{
{"ok", args{want.name, want.kid, want.caURL, want.caRoot, []byte("password")}, want, false}, {"ok", args{want.name, want.kid, ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, want, false},
{"ok-by-name", args{want.name, "", want.caURL, want.caRoot, []byte("password")}, want, false}, {"ok-by-name", args{want.name, "", ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, want, false},
{"fail-bad-kid", args{want.name, "bad-kid", want.caURL, want.caRoot, []byte("password")}, nil, true}, {"ok-with-bundle", args{want.name, want.kid, ca.URL, []byte("password"), WithCABundle(caBundle)}, want, false},
{"fail-empty-name", args{"", want.kid, want.caURL, want.caRoot, []byte("password")}, nil, true}, {"ok-with-fingerprint", args{want.name, want.kid, ca.URL, []byte("password"), WithRootSHA256(want.fingerprint)}, want, false},
{"fail-bad-name", args{"bad-name", "", want.caURL, want.caRoot, []byte("password")}, nil, true}, {"fail-bad-kid", args{want.name, "bad-kid", ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true},
{"fail-by-password", args{want.name, want.kid, want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, {"fail-empty-name", args{"", want.kid, ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true},
{"fail-by-password-no-kid", args{want.name, "", want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, {"fail-bad-name", args{"bad-name", "", ca.URL, []byte("password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true},
{"fail-by-password", args{want.name, want.kid, ca.URL, []byte("bad-password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true},
{"fail-by-password-no-kid", args{want.name, "", ca.URL, []byte("bad-password"), WithRootFile("testdata/secrets/root_ca.crt")}, nil, true},
{"fail-bad-certificate", args{want.name, want.kid, ca.URL, []byte("password"), WithRootFile("testdata/secrets/federated_ca.crt")}, nil, true},
{"fail-not-found-certificate", args{want.name, want.kid, ca.URL, []byte("password"), WithRootFile("testdata/secrets/missing.crt")}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.caRoot, tt.args.password) got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.password, tt.args.clientOption)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NewProvisioner() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NewProvisioner() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
// Client won't match.
// Make sure it does.
if got != nil {
got.Client = want.Client
}
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewProvisioner() = %v, want %v", got, tt.want) t.Errorf("NewProvisioner() = %v, want %v", got, tt.want)
} }
@ -80,13 +110,13 @@ func TestProvisioner_Token(t *testing.T) {
type fields struct { type fields struct {
name string name string
kid string kid string
caURL string fingerprint string
caRoot string
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
tokenLifetime time.Duration tokenLifetime time.Duration
} }
type args struct { type args struct {
subject string subject string
sans []string
} }
tests := []struct { tests := []struct {
name string name string
@ -94,21 +124,23 @@ func TestProvisioner_Token(t *testing.T) {
args args args args
wantErr bool wantErr bool
}{ }{
{"ok", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{"subject"}, false}, {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", nil}, false},
{"fail-no-subject", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{""}, true}, {"ok-with-san", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com"}}, false},
{"fail-no-key", fields{p.name, p.kid, p.caURL, p.caRoot, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject"}, true}, {"ok-with-sans", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com", "127.0.0.1"}}, false},
{"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"", []string{"foo.smallstep.com"}}, true},
{"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject", nil}, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p := &Provisioner{ p := &Provisioner{
name: tt.fields.name, name: tt.fields.name,
kid: tt.fields.kid, kid: tt.fields.kid,
caURL: tt.fields.caURL, audience: "https://127.0.0.1:9000/1.0/sign",
caRoot: tt.fields.caRoot, fingerprint: tt.fields.fingerprint,
jwk: tt.fields.jwk, jwk: tt.fields.jwk,
tokenLifetime: tt.fields.tokenLifetime, tokenLifetime: tt.fields.tokenLifetime,
} }
got, err := p.Token(tt.args.subject) got, err := p.Token(tt.args.subject, tt.args.sans...)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -126,7 +158,7 @@ func TestProvisioner_Token(t *testing.T) {
return return
} }
if err := claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
Audience: []string{tt.fields.caURL + "/1.0/sign"}, Audience: []string{"https://127.0.0.1:9000/1.0/sign"},
Issuer: tt.fields.name, Issuer: tt.fields.name,
Subject: tt.args.subject, Subject: tt.args.subject,
Time: time.Now().UTC(), Time: time.Now().UTC(),
@ -146,8 +178,18 @@ func TestProvisioner_Token(t *testing.T) {
if v, ok := allClaims["sha"].(string); !ok || v != sha { if v, ok := allClaims["sha"].(string); !ok || v != sha {
t.Errorf("Claim sha = %s, want %s", v, sha) t.Errorf("Claim sha = %s, want %s", v, sha)
} }
if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { if len(tt.args.sans) == 0 {
t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject}) if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) {
t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject})
}
} else {
want := []interface{}{}
for _, s := range tt.args.sans {
want = append(want, s)
}
if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, want) {
t.Errorf("Claim sans = %s, want %s", v, want)
}
} }
if v, ok := allClaims["jti"].(string); !ok || v == "" { if v, ok := allClaims["jti"].(string); !ok || v == "" {
t.Errorf("Claim jti = %s, want not blank", v) t.Errorf("Claim jti = %s, want not blank", v)