diff --git a/ca/client.go b/ca/client.go index 78a667f2..ddf6ab86 100644 --- a/ca/client.go +++ b/ca/client.go @@ -258,6 +258,11 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) { }, nil } +// SetTransport updates the transport of the internal HTTP client. +func (c *Client) SetTransport(tr http.RoundTripper) { + c.client.Transport = tr +} + // Health performs the health request to the CA and returns the // api.HealthResponse struct. func (c *Client) Health() (*api.HealthResponse, error) { diff --git a/ca/provisioner.go b/ca/provisioner.go new file mode 100644 index 00000000..e0c50362 --- /dev/null +++ b/ca/provisioner.go @@ -0,0 +1,196 @@ +package ca + +import ( + "encoding/json" + "fmt" + "path/filepath" + "time" + + "github.com/pkg/errors" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/cli/config" + "github.com/smallstep/cli/crypto/randutil" + "github.com/smallstep/cli/jose" + "github.com/smallstep/cli/token" + "github.com/smallstep/cli/token/provision" +) + +const ( + tokenLifetime = 5 * time.Minute +) + +// Provisioner is an authorized entity that can sign tokens necessary for +// signature requests. +type Provisioner struct { + name string + kid string + caURL string + caRoot string + jwk *jose.JSONWebKey + tokenLifetime time.Duration +} + +// NewProvisioner loads and decrypts key material from the CA for the named +// provisioner. The key identified by `kid` will be used if specified. If `kid` +// is the empty string we'll use the first key for the named provisioner that +// decrypts using `password`. +func NewProvisioner(name, kid, caURL, caRoot string, password []byte) (*Provisioner, error) { + var jwk *jose.JSONWebKey + var err error + switch { + case name == "": + return nil, errors.New("provisioner name cannot be empty") + case kid == "": + jwk, err = loadProvisionerJWKByName(name, caURL, caRoot, password) + default: + jwk, err = loadProvisionerJWKByKid(kid, caURL, caRoot, password) + } + if err != nil { + return nil, err + } + + return &Provisioner{ + name: name, + kid: jwk.KeyID, + caURL: caURL, + caRoot: caRoot, + jwk: jwk, + tokenLifetime: tokenLifetime, + }, nil +} + +// Name returns the provisioner's name. +func (p *Provisioner) Name() string { + return p.name +} + +// Kid returns the provisioners key ID. +func (p *Provisioner) Kid() string { + return p.kid +} + +// Token generates a bootstrap token for a subject. +func (p *Provisioner) Token(subject string) (string, error) { + // A random jwt id will be used to identify duplicated tokens + jwtID, err := randutil.Hex(64) // 256 bits + if err != nil { + return "", err + } + + notBefore := time.Now() + notAfter := notBefore.Add(tokenLifetime) + signURL := fmt.Sprintf("%v/1.0/sign", p.caURL) + + tokOptions := []token.Options{ + token.WithJWTID(jwtID), + token.WithKid(p.kid), + token.WithIssuer(p.name), + token.WithAudience(signURL), + token.WithValidity(notBefore, notAfter), + token.WithRootCA(p.caRoot), + token.WithSANS([]string{subject}), + } + + tok, err := provision.New(subject, tokOptions...) + if err != nil { + return "", err + } + + return tok.SignedString(p.jwk.Algorithm, p.jwk.Key) +} + +func decryptProvisionerJWK(encryptedKey string, password []byte) (*jose.JSONWebKey, error) { + enc, err := jose.ParseEncrypted(encryptedKey) + if err != nil { + return nil, err + } + data, err := enc.Decrypt(password) + if err != nil { + return nil, err + } + jwk := new(jose.JSONWebKey) + if err := json.Unmarshal(data, jwk); err != nil { + return nil, errors.Wrap(err, "error unmarshaling provisioning key") + } + return jwk, nil +} + +// loadProvisionerJWKByKid retrieves a provisioner key from the CA by key ID and +// decrypts it using the specified password. +func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose.JSONWebKey, error) { + encrypted, err := getProvisionerKey(caURL, caRoot, kid) + if err != nil { + return nil, err + } + + return decryptProvisionerJWK(encrypted, password) +} + +// loadProvisionerJWKByName retrieves the list of provisioners and encrypted key then +// returns the key of the first provisioner with a matching name that can be successfully +// decrypted with the specified password. +func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key *jose.JSONWebKey, err error) { + provisioners, err := getProvisioners(caURL, caRoot) + if err != nil { + err = errors.Wrap(err, "error getting the provisioners") + return + } + + for _, provisioner := range provisioners { + if provisioner.GetName() == name { + if _, encryptedKey, ok := provisioner.GetEncryptedKey(); ok { + key, err = decryptProvisionerJWK(encryptedKey, password) + if err == nil { + return + } + } + } + } + return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name) +} + +// getRootCAPath returns the path where the root CA is stored based on the +// STEPPATH environment variable. +func getRootCAPath() string { + return filepath.Join(config.StepPath(), "certs", "root_ca.crt") +} + +// getProvisioners returns the map of provisioners on the given CA. +func getProvisioners(caURL, rootFile string) (provisioner.List, error) { + if len(rootFile) == 0 { + rootFile = getRootCAPath() + } + client, err := NewClient(caURL, WithRootFile(rootFile)) + if err != nil { + return nil, err + } + cursor := "" + var provisioners provisioner.List + for { + resp, err := client.Provisioners(WithProvisionerCursor(cursor), WithProvisionerLimit(100)) + if err != nil { + return nil, err + } + provisioners = append(provisioners, resp.Provisioners...) + if resp.NextCursor == "" { + return provisioners, nil + } + cursor = resp.NextCursor + } +} + +// getProvisionerKey returns the encrypted provisioner key for the given kid. +func getProvisionerKey(caURL, rootFile, kid string) (string, error) { + if len(rootFile) == 0 { + rootFile = getRootCAPath() + } + client, err := NewClient(caURL, WithRootFile(rootFile)) + if err != nil { + return "", err + } + resp, err := client.ProvisionerKey(kid) + if err != nil { + return "", err + } + return resp.Key, nil +} diff --git a/ca/provisioner_test.go b/ca/provisioner_test.go new file mode 100644 index 00000000..bc8a2b68 --- /dev/null +++ b/ca/provisioner_test.go @@ -0,0 +1,158 @@ +package ca + +import ( + "reflect" + "testing" + "time" + + "github.com/smallstep/cli/jose" +) + +func getTestProvisioner(t *testing.T, url string) *Provisioner { + jwk, err := jose.ParseKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password"))) + if err != nil { + t.Fatal(err) + } + return &Provisioner{ + name: "mariano", + kid: "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg", + caURL: url, + caRoot: "testdata/secrets/root_ca.crt", + jwk: jwk, + tokenLifetime: 5 * time.Minute, + } +} + +func TestNewProvisioner(t *testing.T) { + ca := startCATestServer() + defer ca.Close() + want := getTestProvisioner(t, ca.URL) + + type args struct { + name string + kid string + caURL string + caRoot string + password []byte + } + tests := []struct { + name string + args args + want *Provisioner + wantErr bool + }{ + {"ok", args{want.name, want.kid, want.caURL, want.caRoot, []byte("password")}, want, false}, + {"ok-by-name", args{want.name, "", want.caURL, want.caRoot, []byte("password")}, want, false}, + {"fail-bad-kid", args{want.name, "bad-kid", want.caURL, want.caRoot, []byte("password")}, nil, true}, + {"fail-empty-name", args{"", want.kid, want.caURL, want.caRoot, []byte("password")}, nil, true}, + {"fail-bad-name", args{"bad-name", "", want.caURL, want.caRoot, []byte("password")}, nil, true}, + {"fail-by-password", args{want.name, want.kid, want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, + {"fail-by-password-no-kid", args{want.name, "", want.caURL, want.caRoot, []byte("bad-password")}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.caRoot, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("NewProvisioner() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewProvisioner() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestProvisioner_Getters(t *testing.T) { + p := getTestProvisioner(t, "https://127.0.0.1:9000") + if got := p.Name(); got != p.name { + t.Errorf("Provisioner.Name() = %v, want %v", got, p.name) + } + if got := p.Kid(); got != p.kid { + t.Errorf("Provisioner.Kid() = %v, want %v", got, p.kid) + } +} + +func TestProvisioner_Token(t *testing.T) { + p := getTestProvisioner(t, "https://127.0.0.1:9000") + sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7" + + type fields struct { + name string + kid string + caURL string + caRoot string + jwk *jose.JSONWebKey + tokenLifetime time.Duration + } + type args struct { + subject string + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{"subject"}, false}, + {"fail-no-subject", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{""}, true}, + {"fail-no-key", fields{p.name, p.kid, p.caURL, p.caRoot, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Provisioner{ + name: tt.fields.name, + kid: tt.fields.kid, + caURL: tt.fields.caURL, + caRoot: tt.fields.caRoot, + jwk: tt.fields.jwk, + tokenLifetime: tt.fields.tokenLifetime, + } + got, err := p.Token(tt.args.subject) + if (err != nil) != tt.wantErr { + t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr == false { + jwt, err := jose.ParseSigned(got) + if err != nil { + t.Error(err) + return + } + var claims jose.Claims + if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil { + t.Error(err) + return + } + if err := claims.ValidateWithLeeway(jose.Expected{ + Audience: []string{tt.fields.caURL + "/1.0/sign"}, + Issuer: tt.fields.name, + Subject: tt.args.subject, + Time: time.Now().UTC(), + }, time.Minute); err != nil { + t.Error(err) + return + } + lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time()) + if lifetime != tt.fields.tokenLifetime { + t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime) + } + allClaims := make(map[string]interface{}) + if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil { + t.Error(err) + return + } + if v, ok := allClaims["sha"].(string); !ok || v != sha { + t.Errorf("Claim sha = %s, want %s", v, sha) + } + if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) { + t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject}) + } + if v, ok := allClaims["jti"].(string); !ok || v == "" { + t.Errorf("Claim jti = %s, want not blank", v) + } + } + }) + } +} diff --git a/ca/signal.go b/ca/signal.go index 1b74ac4a..0d950435 100644 --- a/ca/signal.go +++ b/ca/signal.go @@ -7,6 +7,12 @@ import ( "syscall" ) +// Stopper is the interface that external commands can implement to stop the +// server. +type Stopper interface { + Stop() error +} + // StopReloader is the interface that external commands can implement to stop // the server and reload the configuration while running. type StopReloader interface { @@ -14,6 +20,32 @@ type StopReloader interface { Reload() error } +// StopHandler watches SIGINT, SIGTERM on a list of servers implementing the +// Stopper interface, and when one of those signals is caught we'll run Stop +// (SIGINT, SIGTERM) on all servers. +func StopHandler(servers ...Stopper) { + signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + defer signal.Stop(signals) + + for { + select { + case sig := <-signals: + switch sig { + case syscall.SIGINT, syscall.SIGTERM: + log.Println("shutting down ...") + for _, server := range servers { + err := server.Stop() + if err != nil { + log.Printf("error stopping server: %s", err.Error()) + } + } + return + } + } + } +} + // StopReloaderHandler watches SIGINT, SIGTERM and SIGHUP on a list of servers // implementing the StopReloader interface, and when one of those signals is // caught we'll run Stop (SIGINT, SIGTERM) or Reload (SIGHUP) on all servers. diff --git a/ca/tls.go b/ca/tls.go index ef3af548..79493eb1 100644 --- a/ca/tls.go +++ b/ca/tls.go @@ -60,7 +60,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse, renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Update client transport - c.client.Transport = tr + c.SetTransport(tr) // Start renewer renewer.RunContext(ctx) @@ -111,7 +111,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse, renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) // Update client transport - c.client.Transport = tr + c.SetTransport(tr) // Start renewer renewer.RunContext(ctx)