forked from TrueCloudLab/certificates
Merge pull request #64 from smallstep/step-sds
Add token generator to ca package
This commit is contained in:
commit
6af1e95c5b
5 changed files with 393 additions and 2 deletions
|
@ -258,6 +258,11 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTransport updates the transport of the internal HTTP client.
|
||||||
|
func (c *Client) SetTransport(tr http.RoundTripper) {
|
||||||
|
c.client.Transport = tr
|
||||||
|
}
|
||||||
|
|
||||||
// Health performs the health request to the CA and returns the
|
// Health performs the health request to the CA and returns the
|
||||||
// api.HealthResponse struct.
|
// api.HealthResponse struct.
|
||||||
func (c *Client) Health() (*api.HealthResponse, error) {
|
func (c *Client) Health() (*api.HealthResponse, error) {
|
||||||
|
|
196
ca/provisioner.go
Normal file
196
ca/provisioner.go
Normal file
|
@ -0,0 +1,196 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/certificates/authority/provisioner"
|
||||||
|
"github.com/smallstep/cli/config"
|
||||||
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
"github.com/smallstep/cli/token"
|
||||||
|
"github.com/smallstep/cli/token/provision"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tokenLifetime = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provisioner is an authorized entity that can sign tokens necessary for
|
||||||
|
// signature requests.
|
||||||
|
type Provisioner struct {
|
||||||
|
name string
|
||||||
|
kid string
|
||||||
|
caURL string
|
||||||
|
caRoot string
|
||||||
|
jwk *jose.JSONWebKey
|
||||||
|
tokenLifetime time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvisioner loads and decrypts key material from the CA for the named
|
||||||
|
// provisioner. The key identified by `kid` will be used if specified. If `kid`
|
||||||
|
// is the empty string we'll use the first key for the named provisioner that
|
||||||
|
// decrypts using `password`.
|
||||||
|
func NewProvisioner(name, kid, caURL, caRoot string, password []byte) (*Provisioner, error) {
|
||||||
|
var jwk *jose.JSONWebKey
|
||||||
|
var err error
|
||||||
|
switch {
|
||||||
|
case name == "":
|
||||||
|
return nil, errors.New("provisioner name cannot be empty")
|
||||||
|
case kid == "":
|
||||||
|
jwk, err = loadProvisionerJWKByName(name, caURL, caRoot, password)
|
||||||
|
default:
|
||||||
|
jwk, err = loadProvisionerJWKByKid(kid, caURL, caRoot, password)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Provisioner{
|
||||||
|
name: name,
|
||||||
|
kid: jwk.KeyID,
|
||||||
|
caURL: caURL,
|
||||||
|
caRoot: caRoot,
|
||||||
|
jwk: jwk,
|
||||||
|
tokenLifetime: tokenLifetime,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the provisioner's name.
|
||||||
|
func (p *Provisioner) Name() string {
|
||||||
|
return p.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kid returns the provisioners key ID.
|
||||||
|
func (p *Provisioner) Kid() string {
|
||||||
|
return p.kid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token generates a bootstrap token for a subject.
|
||||||
|
func (p *Provisioner) Token(subject string) (string, error) {
|
||||||
|
// A random jwt id will be used to identify duplicated tokens
|
||||||
|
jwtID, err := randutil.Hex(64) // 256 bits
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
notBefore := time.Now()
|
||||||
|
notAfter := notBefore.Add(tokenLifetime)
|
||||||
|
signURL := fmt.Sprintf("%v/1.0/sign", p.caURL)
|
||||||
|
|
||||||
|
tokOptions := []token.Options{
|
||||||
|
token.WithJWTID(jwtID),
|
||||||
|
token.WithKid(p.kid),
|
||||||
|
token.WithIssuer(p.name),
|
||||||
|
token.WithAudience(signURL),
|
||||||
|
token.WithValidity(notBefore, notAfter),
|
||||||
|
token.WithRootCA(p.caRoot),
|
||||||
|
token.WithSANS([]string{subject}),
|
||||||
|
}
|
||||||
|
|
||||||
|
tok, err := provision.New(subject, tokOptions...)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tok.SignedString(p.jwk.Algorithm, p.jwk.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decryptProvisionerJWK(encryptedKey string, password []byte) (*jose.JSONWebKey, error) {
|
||||||
|
enc, err := jose.ParseEncrypted(encryptedKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data, err := enc.Decrypt(password)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jwk := new(jose.JSONWebKey)
|
||||||
|
if err := json.Unmarshal(data, jwk); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error unmarshaling provisioning key")
|
||||||
|
}
|
||||||
|
return jwk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadProvisionerJWKByKid retrieves a provisioner key from the CA by key ID and
|
||||||
|
// decrypts it using the specified password.
|
||||||
|
func loadProvisionerJWKByKid(kid, caURL, caRoot string, password []byte) (*jose.JSONWebKey, error) {
|
||||||
|
encrypted, err := getProvisionerKey(caURL, caRoot, kid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return decryptProvisionerJWK(encrypted, password)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadProvisionerJWKByName retrieves the list of provisioners and encrypted key then
|
||||||
|
// returns the key of the first provisioner with a matching name that can be successfully
|
||||||
|
// decrypted with the specified password.
|
||||||
|
func loadProvisionerJWKByName(name, caURL, caRoot string, password []byte) (key *jose.JSONWebKey, err error) {
|
||||||
|
provisioners, err := getProvisioners(caURL, caRoot)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.Wrap(err, "error getting the provisioners")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, provisioner := range provisioners {
|
||||||
|
if provisioner.GetName() == name {
|
||||||
|
if _, encryptedKey, ok := provisioner.GetEncryptedKey(); ok {
|
||||||
|
key, err = decryptProvisionerJWK(encryptedKey, password)
|
||||||
|
if err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.Errorf("provisioner '%s' not found (or your password is wrong)", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRootCAPath returns the path where the root CA is stored based on the
|
||||||
|
// STEPPATH environment variable.
|
||||||
|
func getRootCAPath() string {
|
||||||
|
return filepath.Join(config.StepPath(), "certs", "root_ca.crt")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProvisioners returns the map of provisioners on the given CA.
|
||||||
|
func getProvisioners(caURL, rootFile string) (provisioner.List, error) {
|
||||||
|
if len(rootFile) == 0 {
|
||||||
|
rootFile = getRootCAPath()
|
||||||
|
}
|
||||||
|
client, err := NewClient(caURL, WithRootFile(rootFile))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cursor := ""
|
||||||
|
var provisioners provisioner.List
|
||||||
|
for {
|
||||||
|
resp, err := client.Provisioners(WithProvisionerCursor(cursor), WithProvisionerLimit(100))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
provisioners = append(provisioners, resp.Provisioners...)
|
||||||
|
if resp.NextCursor == "" {
|
||||||
|
return provisioners, nil
|
||||||
|
}
|
||||||
|
cursor = resp.NextCursor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProvisionerKey returns the encrypted provisioner key for the given kid.
|
||||||
|
func getProvisionerKey(caURL, rootFile, kid string) (string, error) {
|
||||||
|
if len(rootFile) == 0 {
|
||||||
|
rootFile = getRootCAPath()
|
||||||
|
}
|
||||||
|
client, err := NewClient(caURL, WithRootFile(rootFile))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
resp, err := client.ProvisionerKey(kid)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return resp.Key, nil
|
||||||
|
}
|
158
ca/provisioner_test.go
Normal file
158
ca/provisioner_test.go
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
package ca
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/smallstep/cli/jose"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getTestProvisioner(t *testing.T, url string) *Provisioner {
|
||||||
|
jwk, err := jose.ParseKey("testdata/secrets/ott_mariano_priv.jwk", jose.WithPassword([]byte("password")))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return &Provisioner{
|
||||||
|
name: "mariano",
|
||||||
|
kid: "FLIV7q23CXHrg75J2OSbvzwKJJqoxCYixjmsJirneOg",
|
||||||
|
caURL: url,
|
||||||
|
caRoot: "testdata/secrets/root_ca.crt",
|
||||||
|
jwk: jwk,
|
||||||
|
tokenLifetime: 5 * time.Minute,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProvisioner(t *testing.T) {
|
||||||
|
ca := startCATestServer()
|
||||||
|
defer ca.Close()
|
||||||
|
want := getTestProvisioner(t, ca.URL)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
name string
|
||||||
|
kid string
|
||||||
|
caURL string
|
||||||
|
caRoot string
|
||||||
|
password []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *Provisioner
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{want.name, want.kid, want.caURL, want.caRoot, []byte("password")}, want, false},
|
||||||
|
{"ok-by-name", args{want.name, "", want.caURL, want.caRoot, []byte("password")}, want, false},
|
||||||
|
{"fail-bad-kid", args{want.name, "bad-kid", want.caURL, want.caRoot, []byte("password")}, nil, true},
|
||||||
|
{"fail-empty-name", args{"", want.kid, want.caURL, want.caRoot, []byte("password")}, nil, true},
|
||||||
|
{"fail-bad-name", args{"bad-name", "", want.caURL, want.caRoot, []byte("password")}, nil, true},
|
||||||
|
{"fail-by-password", args{want.name, want.kid, want.caURL, want.caRoot, []byte("bad-password")}, nil, true},
|
||||||
|
{"fail-by-password-no-kid", args{want.name, "", want.caURL, want.caRoot, []byte("bad-password")}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := NewProvisioner(tt.args.name, tt.args.kid, tt.args.caURL, tt.args.caRoot, tt.args.password)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("NewProvisioner() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("NewProvisioner() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvisioner_Getters(t *testing.T) {
|
||||||
|
p := getTestProvisioner(t, "https://127.0.0.1:9000")
|
||||||
|
if got := p.Name(); got != p.name {
|
||||||
|
t.Errorf("Provisioner.Name() = %v, want %v", got, p.name)
|
||||||
|
}
|
||||||
|
if got := p.Kid(); got != p.kid {
|
||||||
|
t.Errorf("Provisioner.Kid() = %v, want %v", got, p.kid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProvisioner_Token(t *testing.T) {
|
||||||
|
p := getTestProvisioner(t, "https://127.0.0.1:9000")
|
||||||
|
sha := "ef742f95dc0d8aa82d3cca4017af6dac3fce84290344159891952d18c53eefe7"
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
name string
|
||||||
|
kid string
|
||||||
|
caURL string
|
||||||
|
caRoot string
|
||||||
|
jwk *jose.JSONWebKey
|
||||||
|
tokenLifetime time.Duration
|
||||||
|
}
|
||||||
|
type args struct {
|
||||||
|
subject string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{"subject"}, false},
|
||||||
|
{"fail-no-subject", fields{p.name, p.kid, p.caURL, p.caRoot, p.jwk, p.tokenLifetime}, args{""}, true},
|
||||||
|
{"fail-no-key", fields{p.name, p.kid, p.caURL, p.caRoot, &jose.JSONWebKey{}, p.tokenLifetime}, args{"subject"}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Provisioner{
|
||||||
|
name: tt.fields.name,
|
||||||
|
kid: tt.fields.kid,
|
||||||
|
caURL: tt.fields.caURL,
|
||||||
|
caRoot: tt.fields.caRoot,
|
||||||
|
jwk: tt.fields.jwk,
|
||||||
|
tokenLifetime: tt.fields.tokenLifetime,
|
||||||
|
}
|
||||||
|
got, err := p.Token(tt.args.subject)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Provisioner.Token() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantErr == false {
|
||||||
|
jwt, err := jose.ParseSigned(got)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var claims jose.Claims
|
||||||
|
if err := jwt.Claims(tt.fields.jwk.Public(), &claims); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := claims.ValidateWithLeeway(jose.Expected{
|
||||||
|
Audience: []string{tt.fields.caURL + "/1.0/sign"},
|
||||||
|
Issuer: tt.fields.name,
|
||||||
|
Subject: tt.args.subject,
|
||||||
|
Time: time.Now().UTC(),
|
||||||
|
}, time.Minute); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lifetime := claims.Expiry.Time().Sub(claims.NotBefore.Time())
|
||||||
|
if lifetime != tt.fields.tokenLifetime {
|
||||||
|
t.Errorf("Claims token life time = %s, want %s", lifetime, tt.fields.tokenLifetime)
|
||||||
|
}
|
||||||
|
allClaims := make(map[string]interface{})
|
||||||
|
if err := jwt.Claims(tt.fields.jwk.Public(), &allClaims); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if v, ok := allClaims["sha"].(string); !ok || v != sha {
|
||||||
|
t.Errorf("Claim sha = %s, want %s", v, sha)
|
||||||
|
}
|
||||||
|
if v, ok := allClaims["sans"].([]interface{}); !ok || !reflect.DeepEqual(v, []interface{}{tt.args.subject}) {
|
||||||
|
t.Errorf("Claim sans = %s, want %s", v, []interface{}{tt.args.subject})
|
||||||
|
}
|
||||||
|
if v, ok := allClaims["jti"].(string); !ok || v == "" {
|
||||||
|
t.Errorf("Claim jti = %s, want not blank", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
32
ca/signal.go
32
ca/signal.go
|
@ -7,6 +7,12 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Stopper is the interface that external commands can implement to stop the
|
||||||
|
// server.
|
||||||
|
type Stopper interface {
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
// StopReloader is the interface that external commands can implement to stop
|
// StopReloader is the interface that external commands can implement to stop
|
||||||
// the server and reload the configuration while running.
|
// the server and reload the configuration while running.
|
||||||
type StopReloader interface {
|
type StopReloader interface {
|
||||||
|
@ -14,6 +20,32 @@ type StopReloader interface {
|
||||||
Reload() error
|
Reload() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StopHandler watches SIGINT, SIGTERM on a list of servers implementing the
|
||||||
|
// Stopper interface, and when one of those signals is caught we'll run Stop
|
||||||
|
// (SIGINT, SIGTERM) on all servers.
|
||||||
|
func StopHandler(servers ...Stopper) {
|
||||||
|
signals := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer signal.Stop(signals)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case sig := <-signals:
|
||||||
|
switch sig {
|
||||||
|
case syscall.SIGINT, syscall.SIGTERM:
|
||||||
|
log.Println("shutting down ...")
|
||||||
|
for _, server := range servers {
|
||||||
|
err := server.Stop()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("error stopping server: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// StopReloaderHandler watches SIGINT, SIGTERM and SIGHUP on a list of servers
|
// StopReloaderHandler watches SIGINT, SIGTERM and SIGHUP on a list of servers
|
||||||
// implementing the StopReloader interface, and when one of those signals is
|
// implementing the StopReloader interface, and when one of those signals is
|
||||||
// caught we'll run Stop (SIGINT, SIGTERM) or Reload (SIGHUP) on all servers.
|
// caught we'll run Stop (SIGINT, SIGTERM) or Reload (SIGHUP) on all servers.
|
||||||
|
|
|
@ -60,7 +60,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
// Update client transport
|
// Update client transport
|
||||||
c.client.Transport = tr
|
c.SetTransport(tr)
|
||||||
|
|
||||||
// Start renewer
|
// Start renewer
|
||||||
renewer.RunContext(ctx)
|
renewer.RunContext(ctx)
|
||||||
|
@ -111,7 +111,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
|
||||||
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
|
||||||
|
|
||||||
// Update client transport
|
// Update client transport
|
||||||
c.client.Transport = tr
|
c.SetTransport(tr)
|
||||||
|
|
||||||
// Start renewer
|
// Start renewer
|
||||||
renewer.RunContext(ctx)
|
renewer.RunContext(ctx)
|
||||||
|
|
Loading…
Reference in a new issue