diff --git a/Gopkg.lock b/Gopkg.lock index 7628ffce..8a647e3d 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -338,7 +338,7 @@ [[projects]] branch = "master" - digest = "1:dad8c405e442687cf36b3d236290554168c78c0b71c0f2cef35635933a225fe3" + digest = "1:5be7f46f64720a346290a9aeb734521523ebe4bbc0180c1d9f4355fac578e3c3" name = "github.com/smallstep/cli" packages = [ "command", @@ -359,7 +359,7 @@ "utils", ] pruneopts = "UT" - revision = "8429a2f6f5d6f097b843322a9a8e80d6fd087258" + revision = "ea92c904da59b3892f53435d01902156ec5e171d" [[projects]] branch = "master" diff --git a/ca/client.go b/ca/client.go index ddf6ab86..c9766293 100644 --- a/ca/client.go +++ b/ca/client.go @@ -16,12 +16,15 @@ import ( "io/ioutil" "net/http" "net/url" + "os" + "path/filepath" "strconv" "strings" "github.com/pkg/errors" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority" + "github.com/smallstep/cli/config" "github.com/smallstep/cli/crypto/x509util" "gopkg.in/square/go-jose.v2/jwt" ) @@ -33,6 +36,7 @@ type clientOptions struct { transport http.RoundTripper rootSHA256 string rootFilename string + rootBundle []byte } func (o *clientOptions) apply(opts []ClientOption) (err error) { @@ -47,7 +51,7 @@ func (o *clientOptions) apply(opts []ClientOption) (err error) { // checkTransport checks if other ways to set up a transport have been provided. // If they have it returns an error. func (o *clientOptions) checkTransport() error { - if o.transport != nil || o.rootFilename != "" || o.rootSHA256 != "" { + if o.transport != nil || o.rootFilename != "" || o.rootSHA256 != "" || o.rootBundle != nil { return errors.New("multiple transport methods have been configured") } return nil @@ -68,14 +72,27 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err return nil, err } } + if o.rootBundle != nil { + if tr, err = getTransportFromCABundle(o.rootBundle); err != nil { + return nil, err + } + } + // As the last option attempt to load the default root ca if tr == nil { + rootFile := getRootCAPath() + if _, err := os.Stat(rootFile); err == nil { + if tr, err = getTransportFromFile(rootFile); err != nil { + return nil, err + } + return tr, nil + } return nil, errors.New("a transport, a root cert, or a root sha256 must be used") } return tr, nil } -// WithTransport adds a custom transport to the Client. If the transport is -// given is given it will have preference over WithRootFile and WithRootSHA256. +// WithTransport adds a custom transport to the Client. It will fail if a +// previous option to create the transport has been configured. func WithTransport(tr http.RoundTripper) ClientOption { return func(o *clientOptions) error { if err := o.checkTransport(); err != nil { @@ -86,9 +103,8 @@ func WithTransport(tr http.RoundTripper) ClientOption { } } -// WithRootFile will create the transport using the given root certificate. If -// the root file is given it will have preference over WithRootSHA256, but less -// preference than WithTransport. +// WithRootFile will create the transport using the given root certificate. It +// will fail if a previous option to create the transport has been configured. func WithRootFile(filename string) ClientOption { return func(o *clientOptions) error { if err := o.checkTransport(); err != nil { @@ -99,8 +115,9 @@ func WithRootFile(filename string) ClientOption { } } -// WithRootSHA256 will create the transport using an insecure client to retrieve the -// root certificate. It has less preference than WithTransport and WithRootFile. +// WithRootSHA256 will create the transport using an insecure client to retrieve +// the root certificate using its fingerprint. It will fail if a previous option +// to create the transport has been configured. func WithRootSHA256(sum string) ClientOption { return func(o *clientOptions) error { if err := o.checkTransport(); err != nil { @@ -111,6 +128,18 @@ func WithRootSHA256(sum string) ClientOption { } } +// WithCABundle will create the transport using the given root certificates. It +// will fail if a previous option to create the transport has been configured. +func WithCABundle(bundle []byte) ClientOption { + return func(o *clientOptions) error { + if err := o.checkTransport(); err != nil { + return err + } + o.rootBundle = bundle + return nil + } +} + func getTransportFromFile(filename string) (http.RoundTripper, error) { data, err := ioutil.ReadFile(filename) if err != nil { @@ -146,6 +175,18 @@ func getTransportFromSHA256(endpoint, sum string) (http.RoundTripper, error) { }) } +func getTransportFromCABundle(bundle []byte) (http.RoundTripper, error) { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(bundle) { + return nil, errors.New("error parsing ca bundle: no certificates found") + } + return getDefaultTransport(&tls.Config{ + MinVersion: tls.VersionTLS12, + PreferServerCipherSuites: true, + RootCAs: pool, + }) +} + // parseEndpoint parses and validates the given endpoint. It supports general // URLs like https://ca.smallstep.com[:port][/path], and incomplete URLs like // ca.smallstep.com[:port][/path]. @@ -464,6 +505,25 @@ func (c *Client) Federation() (*api.FederationResponse, error) { return &federation, nil } +// RootFingerprint is a helper method that returns the current root fingerprint. +// It does an health connection and gets the fingerprint from the TLS verified +// chains. +func (c *Client) RootFingerprint() (string, error) { + u := c.endpoint.ResolveReference(&url.URL{Path: "/health"}) + resp, err := c.client.Get(u.String()) + if err != nil { + return "", errors.Wrapf(err, "client GET %s failed", u) + } + if resp.TLS == nil || len(resp.TLS.VerifiedChains) == 0 { + return "", errors.New("missing verified chains") + } + lastChain := resp.TLS.VerifiedChains[len(resp.TLS.VerifiedChains)-1] + if len(lastChain) == 0 { + return "", errors.New("missing verified chains") + } + return x509util.Fingerprint(lastChain[len(lastChain)-1]), nil +} + // CreateSignRequest is a helper function that given an x509 OTT returns a // simple but secure sign request as well as the private key used. func CreateSignRequest(ott string) (*api.SignRequest, crypto.PrivateKey, error) { @@ -522,6 +582,12 @@ func getInsecureClient() *http.Client { } } +// getRootCAPath returns the path where the root CA is stored based on the +// STEPPATH environment variable. +func getRootCAPath() string { + return filepath.Join(config.StepPath(), "certs", "root_ca.crt") +} + func readJSON(r io.ReadCloser, v interface{}) error { defer r.Close() return json.NewDecoder(r).Decode(v) diff --git a/ca/client_test.go b/ca/client_test.go index c8fada9b..5fb9aae4 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -13,8 +13,10 @@ import ( "testing" "time" + "github.com/smallstep/assert" "github.com/smallstep/certificates/api" "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/crypto/x509util" ) const ( @@ -746,3 +748,66 @@ func Test_parseEndpoint(t *testing.T) { }) } } + +func TestClient_RootFingerprint(t *testing.T) { + ok := &api.HealthResponse{Status: "ok"} + nok := api.InternalServerError(fmt.Errorf("Internal Server Error")) + + httpsServer := httptest.NewTLSServer(nil) + defer httpsServer.Close() + httpsServerFingerprint := x509util.Fingerprint(httpsServer.Certificate()) + + httpServer := httptest.NewServer(nil) + defer httpServer.Close() + + tests := []struct { + name string + server *httptest.Server + response interface{} + responseCode int + want string + wantErr bool + }{ + {"ok", httpsServer, ok, 200, httpsServerFingerprint, false}, + {"ok with error", httpsServer, nok, 500, httpsServerFingerprint, false}, + {"fail", httpServer, ok, 200, "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr := tt.server.Client().Transport + c, err := NewClient(tt.server.URL, WithTransport(tr)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + tt.server.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(tt.responseCode) + api.JSON(w, tt.response) + }) + + got, err := c.RootFingerprint() + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.RootFingerprint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Client.RootFingerprint() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestClient_RootFingerprintWithServer(t *testing.T) { + srv := startCABootstrapServer() + defer srv.Close() + + client, err := NewClient(srv.URL+"/sign", WithRootFile("testdata/secrets/root_ca.crt")) + assert.FatalError(t, err) + + fp, err := client.RootFingerprint() + assert.FatalError(t, err) + assert.Equals(t, "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7", fp) +} diff --git a/ca/provisioner.go b/ca/provisioner.go index e0c50362..bc1acb94 100644 --- a/ca/provisioner.go +++ b/ca/provisioner.go @@ -2,30 +2,27 @@ package ca import ( "encoding/json" - "fmt" - "path/filepath" + "net/url" "time" "github.com/pkg/errors" "github.com/smallstep/certificates/authority/provisioner" - "github.com/smallstep/cli/config" "github.com/smallstep/cli/crypto/randutil" "github.com/smallstep/cli/jose" "github.com/smallstep/cli/token" "github.com/smallstep/cli/token/provision" ) -const ( - tokenLifetime = 5 * time.Minute -) +const tokenLifetime = 5 * time.Minute // Provisioner is an authorized entity that can sign tokens necessary for // signature requests. type Provisioner struct { + *Client name string kid string - caURL string - caRoot string + audience string + fingerprint string jwk *jose.JSONWebKey tokenLifetime time.Duration } @@ -34,26 +31,36 @@ type Provisioner struct { // provisioner. The key identified by `kid` will be used if specified. If `kid` // is the empty string we'll use the first key for the named provisioner that // decrypts using `password`. -func NewProvisioner(name, kid, caURL, caRoot string, password []byte) (*Provisioner, error) { - var jwk *jose.JSONWebKey - var err error - switch { - case name == "": - return nil, errors.New("provisioner name cannot be empty") - case kid == "": - jwk, err = loadProvisionerJWKByName(name, caURL, caRoot, password) - default: - jwk, err = loadProvisionerJWKByKid(kid, caURL, caRoot, password) - } +func NewProvisioner(name, kid, caURL string, password []byte, opts ...ClientOption) (*Provisioner, error) { + client, err := NewClient(caURL, opts...) if err != nil { return nil, err } + // Get the fingerprint of the current connection + fp, err := client.RootFingerprint() + if err != nil { + return nil, err + } + + var jwk *jose.JSONWebKey + switch { + case name == "": + return nil, errors.New("provisioner name cannot be empty") + case kid == "": + jwk, err = loadProvisionerJWKByName(client, name, password) + default: + jwk, err = loadProvisionerJWKByKid(client, kid, password) + } + if err != nil { + return nil, err + } return &Provisioner{ + Client: client, name: name, kid: jwk.KeyID, - caURL: caURL, - caRoot: caRoot, + audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(), + fingerprint: fp, jwk: jwk, tokenLifetime: tokenLifetime, }, nil @@ -69,8 +76,17 @@ func (p *Provisioner) Kid() string { return p.kid } +// SetFingerprint overwrites the default fingerprint used. +func (p *Provisioner) SetFingerprint(sum string) { + p.fingerprint = sum +} + // Token generates a bootstrap token for a subject. -func (p *Provisioner) Token(subject string) (string, error) { +func (p *Provisioner) Token(subject string, sans ...string) (string, error) { + if len(sans) == 0 { + sans = []string{subject} + } + // A random jwt id will be used to identify duplicated tokens jwtID, err := randutil.Hex(64) // 256 bits if err != nil { @@ -79,16 +95,17 @@ func (p *Provisioner) Token(subject string) (string, error) { notBefore := time.Now() notAfter := notBefore.Add(tokenLifetime) - signURL := fmt.Sprintf("%v/1.0/sign", p.caURL) - tokOptions := []token.Options{ token.WithJWTID(jwtID), token.WithKid(p.kid), token.WithIssuer(p.name), - token.WithAudience(signURL), + token.WithAudience(p.audience), token.WithValidity(notBefore, notAfter), - token.WithRootCA(p.caRoot), - token.WithSANS([]string{subject}), + token.WithSANS(sans), + } + + if p.fingerprint != "" { + tokOptions = append(tokOptions, token.WithSHA(p.fingerprint)) } tok, err := provision.New(subject, tokOptions...) @@ -117,8 +134,8 @@ func decryptProvisionerJWK(encryptedKey string, password []byte) (*jose.JSONWebK // loadProvisionerJWKByKid retrieves a provisioner key from the CA by key ID and // decrypts it using the specified password. -func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose.JSONWebKey, error) { - encrypted, err := getProvisionerKey(caURL, caRoot, kid) +func loadProvisionerJWKByKid(client *Client, kid string, password []byte) (*jose.JSONWebKey, error) { + encrypted, err := getProvisionerKey(client, kid) if err != nil { return nil, err } @@ -129,8 +146,8 @@ func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose. // loadProvisionerJWKByName retrieves the list of provisioners and encrypted key then // returns the key of the first provisioner with a matching name that can be successfully // decrypted with the specified password. -func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key *jose.JSONWebKey, err error) { - provisioners, err := getProvisioners(caURL, caRoot) +func loadProvisionerJWKByName(client *Client, name string, password []byte) (key *jose.JSONWebKey, err error) { + provisioners, err := getProvisioners(client) if err != nil { err = errors.Wrap(err, "error getting the provisioners") return @@ -149,22 +166,9 @@ func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name) } -// getRootCAPath returns the path where the root CA is stored based on the -// STEPPATH environment variable. -func getRootCAPath() string { - return filepath.Join(config.StepPath(), "certs", "root_ca.crt") -} - -// getProvisioners returns the map of provisioners on the given CA. -func getProvisioners(caURL, rootFile string) (provisioner.List, error) { - if len(rootFile) == 0 { - rootFile = getRootCAPath() - } - client, err := NewClient(caURL, WithRootFile(rootFile)) - if err != nil { - return nil, err - } - cursor := "" +// getProvisioners returns the list of provisioners using the configured client. +func getProvisioners(client *Client) (provisioner.List, error) { + var cursor string var provisioners provisioner.List for { resp, err := client.Provisioners(WithProvisionerCursor(cursor), WithProvisionerLimit(100)) @@ -180,14 +184,7 @@ func getProvisioners(caURL, rootFile string) (provisioner.List, error) { } // getProvisionerKey returns the encrypted provisioner key for the given kid. -func getProvisionerKey(caURL, rootFile, kid string) (string, error) { - if len(rootFile) == 0 { - rootFile = getRootCAPath() - } - client, err := NewClient(caURL, WithRootFile(rootFile)) - if err != nil { - return "", err - } +func getProvisionerKey(client *Client, kid string) (string, error) { resp, err := client.ProvisionerKey(kid) if err != nil { return "", err diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go index bc8a2b68..02541d4a 100644 --- a/ca/provisioner_test.go +++ b/ca/provisioner_test.go @@ -1,23 +1,38 @@ package ca import ( + "net/url" "reflect" "testing" "time" + "github.com/smallstep/cli/crypto/pemutil" + "github.com/smallstep/cli/crypto/x509util" "github.com/smallstep/cli/jose" ) -func getTestProvisioner(t *testing.T, url string) *Provisioner { +func getTestProvisioner(t *testing.T, caURL string) *Provisioner { jwk, err := jose.ParseKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) if err != nil { t.Fatal(err) } + + cert, err := pemutil.ReadCertificate("testdata/secrets/root_ca.crt") + if err != nil { + t.Fatal(err) + } + + client, err := NewClient(caURL) + if err != nil { + t.Fatal(err) + } + return &Provisioner{ + Client: client, name: "mariano", kid: "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", - caURL: url, - caRoot: "testdata/secrets/root_ca.crt", + audience: client.endpoint.ResolveReference(&url.URL{Path: "/1.0/sign"}).String(), + fingerprint: x509util.Fingerprint(cert), jwk: jwk, tokenLifetime: 5 * time.Minute, } @@ -32,8 +47,8 @@ func TestNewProvisioner(t *testing.T) { name string kid string caURL string - caRoot string password []byte + caRoot string } tests := []struct { name string @@ -41,21 +56,27 @@ func TestNewProvisioner(t *testing.T) { want *Provisioner wantErr bool }{ - {"ok", args{want.name, want.kid, want.caURL, want.caRoot, []byte("password")}, want, false}, - {"ok-by-name", args{want.name, "", want.caURL, want.caRoot, []byte("password")}, want, false}, - {"fail-bad-kid", args{want.name, "bad-kid", want.caURL, want.caRoot, []byte("password")}, nil, true}, - {"fail-empty-name", args{"", want.kid, want.caURL, want.caRoot, []byte("password")}, nil, true}, - {"fail-bad-name", args{"bad-name", "", want.caURL, want.caRoot, []byte("password")}, nil, true}, - {"fail-by-password", args{want.name, want.kid, want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, - {"fail-by-password-no-kid", args{want.name, "", want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, + {"ok", args{want.name, want.kid, ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, want, false}, + {"ok-by-name", args{want.name, "", ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, want, false}, + {"fail-bad-kid", args{want.name, "bad-kid", ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, nil, true}, + {"fail-empty-name", args{"", want.kid, ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, nil, true}, + {"fail-bad-name", args{"bad-name", "", ca.URL, []byte("password"), "testdata/secrets/root_ca.crt"}, nil, true}, + {"fail-by-password", args{want.name, want.kid, ca.URL, []byte("bad-password"), "testdata/secrets/root_ca.crt"}, nil, true}, + {"fail-by-password-no-kid", args{want.name, "", ca.URL, []byte("bad-password"), "testdata/secrets/root_ca.crt"}, nil, true}, + {"fail-bad-certificate", args{want.name, want.kid, ca.URL, []byte("password"), "testdata/secrets/federatec_ca.crt"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.caRoot, tt.args.password) + got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.password, WithRootFile(tt.args.caRoot)) if (err != nil) != tt.wantErr { t.Errorf("NewProvisioner() error = %v, wantErr %v", err, tt.wantErr) return } + // Client won't match. + // Make sure it does. + if got != nil { + got.Client = want.Client + } if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewProvisioner() = %v, want %v", got, tt.want) } @@ -80,13 +101,13 @@ func TestProvisioner_Token(t *testing.T) { type fields struct { name string kid string - caURL string - caRoot string + fingerprint string jwk *jose.JSONWebKey tokenLifetime time.Duration } type args struct { subject string + sans []string } tests := []struct { name string @@ -94,21 +115,23 @@ func TestProvisioner_Token(t *testing.T) { args args wantErr bool }{ - {"ok", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{"subject"}, false}, - {"fail-no-subject", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{""}, true}, - {"fail-no-key", fields{p.name, p.kid, p.caURL, p.caRoot, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject"}, true}, + {"ok", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", nil}, false}, + {"ok-with-san", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com"}}, false}, + {"ok-with-sans", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"subject", []string{"foo.smallstep.com", "127.0.0.1"}}, false}, + {"fail-no-subject", fields{p.name, p.kid, sha, p.jwk, p.tokenLifetime}, args{"", []string{"foo.smallstep.com"}}, true}, + {"fail-no-key", fields{p.name, p.kid, sha, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject", nil}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { p := &Provisioner{ name: tt.fields.name, kid: tt.fields.kid, - caURL: tt.fields.caURL, - caRoot: tt.fields.caRoot, + audience: "https://127.0.0.1:9000/1.0/sign", + fingerprint: tt.fields.fingerprint, jwk: tt.fields.jwk, tokenLifetime: tt.fields.tokenLifetime, } - got, err := p.Token(tt.args.subject) + got, err := p.Token(tt.args.subject, tt.args.sans...) if (err != nil) != tt.wantErr { t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr) return @@ -126,7 +149,7 @@ func TestProvisioner_Token(t *testing.T) { return } if err := claims.ValidateWithLeeway(jose.Expected{ - Audience: []string{tt.fields.caURL + "/1.0/sign"}, + Audience: []string{"https://127.0.0.1:9000/1.0/sign"}, Issuer: tt.fields.name, Subject: tt.args.subject, Time: time.Now().UTC(), @@ -146,8 +169,18 @@ func TestProvisioner_Token(t *testing.T) { if v, ok := allClaims["sha"].(string); !ok || v != sha { t.Errorf("Claim sha = %s, want %s", v, sha) } - if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { - t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject}) + if len(tt.args.sans) == 0 { + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { + t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject}) + } + } else { + want := []interface{}{} + for _, s := range tt.args.sans { + want = append(want, s) + } + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, want) { + t.Errorf("Claim sans = %s, want %s", v, want) + } } if v, ok := allClaims["jti"].(string); !ok || v == "" { t.Errorf("Claim jti = %s, want not blank", v)