Merge pull request #692 from smallstep/max/context

Context management
This commit is contained in:
Max 2021-11-17 12:06:42 -08:00 committed by GitHub
commit de2ce5cf9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 263 additions and 133 deletions

View file

@ -58,6 +58,7 @@ jobs:
run: V=1 make ci run: V=1 make ci
- -
name: Codecov name: Codecov
if: matrix.go == '1.17'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v1.2.1
with: with:
file: ./coverage.out # optional file: ./coverage.out # optional

View file

@ -9,7 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
) )
@ -245,7 +245,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
return "" return ""
} }
stepPath := filepath.ToSlash(config.StepPath()) stepPath := filepath.ToSlash(step.Path())
if !strings.HasSuffix(stepPath, "/") { if !strings.HasSuffix(stepPath, "/") {
stepPath += "/" stepPath += "/"
} }
@ -257,7 +257,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
panic(err) panic(err)
} }
if ok { if ok {
b, err := os.ReadFile(config.StepAbs(fn)) b, err := os.ReadFile(step.Abs(fn))
if err != nil { if err != nil {
panic(errors.Wrapf(err, "error reading %s", fn)) panic(errors.Wrapf(err, "error reading %s", fn))
} }

View file

@ -13,7 +13,7 @@ import (
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
step "go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/linkedca" "go.step.sm/linkedca"
@ -238,6 +238,8 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
return nil return nil
} }
// CreateFirstProvisioner creates and stores the first provisioner when using
// admin database provisioner storage.
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) { func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
if password == "" { if password == "" {
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one") pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
@ -287,6 +289,7 @@ func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (
return p, nil return p, nil
} }
// ValidateClaims validates the Claims type.
func ValidateClaims(c *linkedca.Claims) error { func ValidateClaims(c *linkedca.Claims) error {
if c == nil { if c == nil {
return nil return nil
@ -313,6 +316,7 @@ func ValidateClaims(c *linkedca.Claims) error {
return nil return nil
} }
// ValidateDurations validates the Durations type.
func ValidateDurations(d *linkedca.Durations) error { func ValidateDurations(d *linkedca.Durations) error {
var ( var (
err error err error
@ -523,7 +527,7 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.X509.Template != "" { if p.X509.Template != "" {
x509Template.Template = []byte(p.SSH.Template) x509Template.Template = []byte(p.SSH.Template)
} else if p.X509.TemplateFile != "" { } else if p.X509.TemplateFile != "" {
filename := step.StepAbs(p.X509.TemplateFile) filename := step.Abs(p.X509.TemplateFile)
if x509Template.Template, err = os.ReadFile(filename); err != nil { if x509Template.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading x509 template") return nil, nil, errors.Wrap(err, "error reading x509 template")
} }
@ -539,7 +543,7 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.SSH.Template != "" { if p.SSH.Template != "" {
sshTemplate.Template = []byte(p.SSH.Template) sshTemplate.Template = []byte(p.SSH.Template)
} else if p.SSH.TemplateFile != "" { } else if p.SSH.TemplateFile != "" {
filename := step.StepAbs(p.SSH.TemplateFile) filename := step.Abs(p.SSH.TemplateFile)
if sshTemplate.Template, err = os.ReadFile(filename); err != nil { if sshTemplate.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading ssh template") return nil, nil, errors.Wrap(err, "error reading ssh template")
} }

View file

@ -101,6 +101,15 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Backwards compatibility for version of the cli older than v0.18.0.
// Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
// from the cli.
if o.Name == "step_includes.tpl" && data[templates.SSHTemplateVersionKey] == "" {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
}
output = append(output, o) output = append(output, o)
} }
return output, nil return output, nil

View file

@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")}, {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
} }
tmplConfigUserIncludes := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}
userOutputEmptyData := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
}
userOutputWithoutTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
userOutputWithTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
tmplConfigErr := &templates.Templates{ tmplConfigErr := &templates.Templates{
SSH: &templates.SSHTemplates{ SSH: &templates.SSHTemplates{
User: []templates.Template{ User: []templates.Template{
@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},

View file

@ -0,0 +1 @@
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}

View file

@ -28,7 +28,7 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -225,7 +225,7 @@ func (o *clientOptions) getTransport(endpoint string) (tr http.RoundTripper, err
return tr, nil return tr, nil
} }
// WithTransport adds a custom transport to the Client. It will fail if a // WithTransport adds a custom transport to the Client. It will fail if a
// previous option to create the transport has been configured. // previous option to create the transport has been configured.
func WithTransport(tr http.RoundTripper) ClientOption { func WithTransport(tr http.RoundTripper) ClientOption {
return func(o *clientOptions) error { return func(o *clientOptions) error {
@ -237,6 +237,17 @@ func WithTransport(tr http.RoundTripper) ClientOption {
} }
} }
// WithInsecure adds a insecure transport that bypasses TLS verification.
func WithInsecure() ClientOption {
return func(o *clientOptions) error {
o.transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return nil
}
}
// WithRootFile will create the transport using the given root certificate. It // WithRootFile will create the transport using the given root certificate. It
// will fail if a previous option to create the transport has been configured. // will fail if a previous option to create the transport has been configured.
func WithRootFile(filename string) ClientOption { func WithRootFile(filename string) ClientOption {
@ -1294,7 +1305,7 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
// getRootCAPath returns the path where the root CA is stored based on the // getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // STEPPATH environment variable.
func getRootCAPath() string { func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt") return filepath.Join(step.Path(), "certs", "root_ca.crt")
} }
func readJSON(r io.ReadCloser, v interface{}) error { func readJSON(r io.ReadCloser, v interface{}) error {

View file

@ -27,21 +27,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json // $STEPPATH/config/identity.json
func LoadClient() (*Client, error) { func LoadClient() (*Client, error) {
b, err := os.ReadFile(DefaultsFile) defaultsFile := DefaultsFile()
b, err := os.ReadFile(defaultsFile)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", DefaultsFile) return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
} }
var defaults defaultsConfig var defaults defaultsConfig
if err := json.Unmarshal(b, &defaults); err != nil { if err := json.Unmarshal(b, &defaults); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile) return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile)
} }
if err := defaults.Validate(); err != nil { if err := defaults.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
caURL, err := url.Parse(defaults.CaURL) caURL, err := url.Parse(defaults.CaURL)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
if caURL.Scheme == "" { if caURL.Scheme == "" {
caURL.Scheme = "https" caURL.Scheme = "https"
@ -52,7 +53,7 @@ func LoadClient() (*Client, error) {
return nil, err return nil, err
} }
if err := identity.Validate(); err != nil { if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", IdentityFile) return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
} }
if kind := identity.Kind(); kind != MutualTLS { if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)

View file

@ -11,6 +11,12 @@ import (
"testing" "testing"
) )
func returnInput(val string) func() string {
return func() string {
return val
}
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
oldIdentityFile := IdentityFile oldIdentityFile := IdentityFile
oldDefaultsFile := DefaultsFile oldDefaultsFile := DefaultsFile
@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
DefaultsFile = oldDefaultsFile DefaultsFile = oldDefaultsFile
}() }()
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
client, err := LoadClient() client, err := LoadClient()
if err != nil { if err != nil {
@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", func() { {"ok", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, expected, false}, }, expected, false},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/missing.json" IdentityFile = returnInput("testdata/config/missing.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/fail.json" IdentityFile = returnInput("testdata/config/fail.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/missing.json" DefaultsFile = returnInput("testdata/config/missing.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/fail.json" DefaultsFile = returnInput("testdata/config/fail.json")
}, nil, true}, }, nil, true},
{"fail ca", func() { {"fail ca", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badca.json" DefaultsFile = returnInput("testdata/config/badca.json")
}, nil, true}, }, nil, true},
{"fail root", func() { {"fail root", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badroot.json" DefaultsFile = returnInput("testdata/config/badroot.json")
}, nil, true}, }, nil, true},
{"fail type", func() { {"fail type", func() {
IdentityFile = "testdata/config/badIdentity.json" IdentityFile = returnInput("testdata/config/badIdentity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -15,7 +15,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -38,11 +38,18 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims. // DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute const DefaultLeeway = 1 * time.Minute
// IdentityFile contains the location of the identity file. var (
var IdentityFile = filepath.Join(config.StepPath(), "config", "identity.json") identityDir = step.IdentityPath
configDir = step.ConfigPath
// DefaultsFile contains the location of the defaults file. // IdentityFile contains a pointer to a function that outputs the location of
var DefaultsFile = filepath.Join(config.StepPath(), "config", "defaults.json") // the identity file.
IdentityFile = step.IdentityFile
// DefaultsFile contains a prointer a function that outputs the location of the
// defaults configuration file.
DefaultsFile = step.DefaultsFile
)
// Identity represents the identity file that can be used to authenticate with // Identity represents the identity file that can be used to authenticate with
// the CA. // the CA.
@ -73,23 +80,17 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity. // LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) { func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(IdentityFile) return LoadIdentity(IdentityFile())
} }
// configDir and identityDir are used in WriteDefaultIdentity for testing
// purposes.
var (
configDir = filepath.Join(config.StepPath(), "config")
identityDir = filepath.Join(config.StepPath(), "identity")
)
// WriteDefaultIdentity writes the given certificates and key and the // WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files. // identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(configDir, 0700); err != nil { if err := os.MkdirAll(configDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory") return errors.Wrap(err, "error creating config directory")
} }
identityDir := identityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil { if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory") return errors.Wrap(err, "error creating identity directory")
} }
@ -126,7 +127,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil { }); err != nil {
return errors.Wrap(err, "error writing identity json") return errors.Wrap(err, "error writing identity json")
} }
if err := os.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -135,7 +136,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk. // WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error { func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(identityDir, "identity.crt") filename := filepath.Join(identityDir(), "identity.crt")
return writeCertificate(filename, certChain) return writeCertificate(filename, certChain)
} }
@ -318,7 +319,7 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate") return errors.Wrap(err, "error encoding identity certificate")
} }
} }
certFilename := filepath.Join(identityDir, "identity.crt") certFilename := filepath.Join(identityDir(), "identity.crt")
if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }

View file

@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) {
want *Identity want *Identity
wantErr bool wantErr bool
}{ }{
{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false}, {"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false},
{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true}, {"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true},
{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true}, {"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) {
certChain = append(certChain, api.Certificate{Certificate: c}) certChain = append(certChain, api.Certificate{Certificate: c})
} }
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(tmpDir, "config", "identity.json") IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json"))
type args struct { type args struct {
certChain []api.Certificate certChain []api.Certificate
@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) {
}{ }{
{"ok", func() {}, args{certChain, key}, false}, {"ok", func() {}, args{certChain, key}, false},
{"fail mkdir config", func() { {"fail mkdir config", func() {
configDir = filepath.Join(tmpDir, "identity", "identity.crt") configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail mkdir identity", func() { {"fail mkdir identity", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity", "identity.crt") identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail certificate", func() { {"fail certificate", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail key", func() { {"fail key", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, "badKey"}, true}, }, args{certChain, "badKey"}, true},
{"fail write identity", func() { {"fail write identity", func() {
configDir = filepath.Join(tmpDir, "bad-dir") configDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(configDir, "identity.json") IdentityFile = returnInput(filepath.Join(configDir(), "identity.json"))
os.MkdirAll(configDir, 0600) os.MkdirAll(configDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
} }
@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) {
} }
oldIdentityDir := identityDir oldIdentityDir := identityDir
identityDir = "testdata/identity" identityDir = returnInput("testdata/identity")
defer func() { defer func() {
identityDir = oldIdentityDir identityDir = oldIdentityDir
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) {
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail write identity", func() { {"fail write identity", func() {
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -21,7 +21,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
"go.step.sm/cli-utils/command" "go.step.sm/cli-utils/command"
"go.step.sm/cli-utils/command/version" "go.step.sm/cli-utils/command/version"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/cli-utils/usage" "go.step.sm/cli-utils/usage"
@ -49,7 +49,7 @@ var (
) )
func init() { func init() {
config.Set("Smallstep CA", Version, BuildTime) step.Set("Smallstep CA", Version, BuildTime)
authority.GlobalVersion.Version = Version authority.GlobalVersion.Version = Version
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }
@ -115,7 +115,7 @@ func main() {
app := cli.NewApp() app := cli.NewApp()
app.Name = "step-ca" app.Name = "step-ca"
app.HelpName = "step-ca" app.HelpName = "step-ca"
app.Version = config.Version() app.Version = step.Version()
app.Usage = "an online certificate authority for secure automated certificate management" app.Usage = "an online certificate authority for secure automated certificate management"
app.UsageText = `**step-ca** <config> [**--password-file**=<file>] app.UsageText = `**step-ca** <config> [**--password-file**=<file>]
[**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>] [**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>]

13
go.mod
View file

@ -29,10 +29,10 @@ require (
github.com/rs/xid v1.2.1 github.com/rs/xid v1.2.1
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262
github.com/smallstep/nosql v0.3.8 github.com/smallstep/nosql v0.3.9
github.com/urfave/cli v1.22.4 github.com/urfave/cli v1.22.4
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.6.2 go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.13.0 go.step.sm/crypto v0.13.0
go.step.sm/linkedca v0.7.0 go.step.sm/linkedca v0.7.0
golang.org/x/crypto v0.0.0-20210915214749-c084706c2272 golang.org/x/crypto v0.0.0-20210915214749-c084706c2272
@ -44,7 +44,8 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/square/go-jose.v2 v2.6.0
) )
// replace github.com/smallstep/nosql => ../nosql //replace github.com/smallstep/nosql => ../nosql
// replace go.step.sm/crypto => ../crypto
// replace go.step.sm/cli-utils => ../cli-utils //replace go.step.sm/crypto => ../crypto
// replace go.step.sm/linkedca => ../linkedca
//replace go.step.sm/cli-utils => ../cli-utils

8
go.sum
View file

@ -494,8 +494,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd
github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
github.com/smallstep/nosql v0.3.8 h1:1/EWUbbEdz9ai0g9Fd09VekVjtxp+5+gIHpV2PdwW3o= github.com/smallstep/nosql v0.3.9 h1:YPy5PR3PXClqmpFaVv0wfXDXDc7NXGBE1auyU2c87dc=
github.com/smallstep/nosql v0.3.8/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs= github.com/smallstep/nosql v0.3.9/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
@ -559,8 +559,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.step.sm/cli-utils v0.6.2 h1:ofa3G/EqE3dTDXmzoXHDZr18qJZoFsKSzbzuF+mxuZU= go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds=
go.step.sm/cli-utils v0.6.2/go.mod h1:0tZ8F2QwLgD6KbKj4nrQZhMakTasEAnOcW3Ekc5pnrA= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E=
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8= go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8=
go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg= go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg=

View file

@ -29,9 +29,9 @@ import (
"github.com/smallstep/certificates/kms" "github.com/smallstep/certificates/kms"
kmsapi "github.com/smallstep/certificates/kms/apiv1" kmsapi "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -87,44 +87,50 @@ const (
) )
// GetDBPath returns the path where the file-system persistence is stored // GetDBPath returns the path where the file-system persistence is stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetDBPath() string { func GetDBPath() string {
return filepath.Join(config.StepPath(), dbPath) return filepath.Join(step.Path(), dbPath)
} }
// GetConfigPath returns the directory where the configuration files are stored // GetConfigPath returns the directory where the configuration files are stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetConfigPath() string { func GetConfigPath() string {
return filepath.Join(config.StepPath(), configPath) return filepath.Join(step.Path(), configPath)
}
// GetProfileConfigPath returns the directory where the profile configuration
// files are stored based on the $(step path).
func GetProfileConfigPath() string {
return filepath.Join(step.ProfilePath(), configPath)
} }
// GetPublicPath returns the directory where the public keys are stored based on // GetPublicPath returns the directory where the public keys are stored based on
// the STEPPATH environment variable. // the $(step path).
func GetPublicPath() string { func GetPublicPath() string {
return filepath.Join(config.StepPath(), publicPath) return filepath.Join(step.Path(), publicPath)
} }
// GetSecretsPath returns the directory where the private keys are stored based // GetSecretsPath returns the directory where the private keys are stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetSecretsPath() string { func GetSecretsPath() string {
return filepath.Join(config.StepPath(), privatePath) return filepath.Join(step.Path(), privatePath)
} }
// GetRootCAPath returns the path where the root CA is stored based on the // GetRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // $(step path).
func GetRootCAPath() string { func GetRootCAPath() string {
return filepath.Join(config.StepPath(), publicPath, "root_ca.crt") return filepath.Join(step.Path(), publicPath, "root_ca.crt")
} }
// GetOTTKeyPath returns the path where the one-time token key is stored based // GetOTTKeyPath returns the path where the one-time token key is stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetOTTKeyPath() string { func GetOTTKeyPath() string {
return filepath.Join(config.StepPath(), privatePath, "ott_key") return filepath.Join(step.Path(), privatePath, "ott_key")
} }
// GetTemplatesPath returns the path where the templates are stored. // GetTemplatesPath returns the path where the templates are stored.
func GetTemplatesPath() string { func GetTemplatesPath() string {
return filepath.Join(config.StepPath(), templatesPath) return filepath.Join(step.Path(), templatesPath)
} }
// GetProvisioners returns the map of provisioners on the given CA. // GetProvisioners returns the map of provisioners on the given CA.
@ -286,20 +292,22 @@ func WithKeyURIs(rootKey, intermediateKey, hostKey, userKey string) Option {
// PKI represents the Public Key Infrastructure used by a certificate authority. // PKI represents the Public Key Infrastructure used by a certificate authority.
type PKI struct { type PKI struct {
linkedca.Configuration linkedca.Configuration
Defaults linkedca.Defaults Defaults linkedca.Defaults
casOptions apiv1.Options casOptions apiv1.Options
caService apiv1.CertificateAuthorityService caService apiv1.CertificateAuthorityService
caCreator apiv1.CertificateAuthorityCreator caCreator apiv1.CertificateAuthorityCreator
keyManager kmsapi.KeyManager keyManager kmsapi.KeyManager
config string config string
defaults string defaults string
ottPublicKey *jose.JSONWebKey profileDefaults string
ottPrivateKey *jose.JSONWebEncryption ottPublicKey *jose.JSONWebKey
options *options ottPrivateKey *jose.JSONWebEncryption
options *options
} }
// New creates a new PKI configuration. // New creates a new PKI configuration.
func New(o apiv1.Options, opts ...Option) (*PKI, error) { func New(o apiv1.Options, opts ...Option) (*PKI, error) {
currentCtx := step.Contexts().GetCurrent()
caService, err := cas.New(context.Background(), o) caService, err := cas.New(context.Background(), o)
if err != nil { if err != nil {
return nil, err return nil, err
@ -358,6 +366,9 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
cfg = GetConfigPath() cfg = GetConfigPath()
// Create directories // Create directories
dirs := []string{public, private, cfg, GetTemplatesPath()} dirs := []string{public, private, cfg, GetTemplatesPath()}
if currentCtx != nil {
dirs = append(dirs, GetProfileConfigPath())
}
for _, name := range dirs { for _, name := range dirs {
if _, err := os.Stat(name); os.IsNotExist(err) { if _, err := os.Stat(name); os.IsNotExist(err) {
if err = os.MkdirAll(name, 0700); err != nil { if err = os.MkdirAll(name, 0700); err != nil {
@ -415,6 +426,10 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
if p.defaults, err = getPath(cfg, "defaults.json"); err != nil { if p.defaults, err = getPath(cfg, "defaults.json"); err != nil {
return nil, err return nil, err
} }
if currentCtx != nil {
p.profileDefaults = currentCtx.ProfileDefaultsFile()
}
if p.config, err = getPath(cfg, "ca.json"); err != nil { if p.config, err = getPath(cfg, "ca.json"); err != nil {
return nil, err return nil, err
} }
@ -944,6 +959,18 @@ func (p *PKI) Save(opt ...ConfigOption) error {
if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil { if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil {
return errs.FileError(err, p.defaults) return errs.FileError(err, p.defaults)
} }
// If we're using contexts then write a blank object to the default profile
// configuration location.
if p.profileDefaults != "" {
if _, err := os.Stat(p.profileDefaults); os.IsNotExist(err) {
// Write with 0600 to be consistent with directories structure.
if err = fileutil.WriteFile(p.profileDefaults, []byte("{}"), 0600); err != nil {
return errs.FileError(err, p.profileDefaults)
}
} else if err != nil {
return errs.FileError(err, p.profileDefaults)
}
}
// Generate and write templates // Generate and write templates
if err := generateTemplates(cfg.Templates); err != nil { if err := generateTemplates(cfg.Templates); err != nil {
@ -958,6 +985,9 @@ func (p *PKI) Save(opt ...ConfigOption) error {
} }
ui.PrintSelected("Default configuration", p.defaults) ui.PrintSelected("Default configuration", p.defaults)
if p.profileDefaults != "" {
ui.PrintSelected("Default profile configuration", p.profileDefaults)
}
ui.PrintSelected("Certificate Authority configuration", p.config) ui.PrintSelected("Certificate Authority configuration", p.config)
if p.options.deploymentType != LinkedDeployment { if p.options.deploymentType != LinkedDeployment {
ui.Println() ui.Println()

View file

@ -6,9 +6,9 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// getTemplates returns all the templates enabled // getTemplates returns all the templates enabled
@ -44,7 +44,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }
@ -53,7 +53,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }

View file

@ -9,8 +9,8 @@ import (
"github.com/Masterminds/sprig/v3" "github.com/Masterminds/sprig/v3"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// TemplateType defines how a template will be written in disk. // TemplateType defines how a template will be written in disk.
@ -19,6 +19,9 @@ type TemplateType string
const ( const (
// Snippet will mark a template as a part of a file. // Snippet will mark a template as a part of a file.
Snippet TemplateType = "snippet" Snippet TemplateType = "snippet"
// PrependLine is a template for prepending a single line to a file. If the
// line already exists in the file it will be removed first.
PrependLine TemplateType = "prepend-line"
// File will mark a templates as a full file. // File will mark a templates as a full file.
File TemplateType = "file" File TemplateType = "file"
// Directory will mark a template as a directory. // Directory will mark a template as a directory.
@ -98,7 +101,7 @@ func (t *SSHTemplates) Validate() (err error) {
return return
} }
// Template represents on template file. // Template represents a template file.
type Template struct { type Template struct {
*template.Template *template.Template
Name string `json:"name"` Name string `json:"name"`
@ -117,8 +120,8 @@ func (t *Template) Validate() error {
return nil return nil
case t.Name == "": case t.Name == "":
return errors.New("template name cannot be empty") return errors.New("template name cannot be empty")
case t.Type != Snippet && t.Type != File && t.Type != Directory: case t.Type != Snippet && t.Type != File && t.Type != Directory && t.Type != PrependLine:
return errors.Errorf("invalid template type %s, it must be %s, %s, or %s", t.Type, Snippet, File, Directory) return errors.Errorf("invalid template type %s, it must be %s, %s, %s, or %s", t.Type, Snippet, PrependLine, File, Directory)
case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0: case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0:
return errors.New("template template cannot be empty") return errors.New("template template cannot be empty")
case t.TemplatePath != "" && t.Type == Directory: case t.TemplatePath != "" && t.Type == Directory:
@ -131,7 +134,7 @@ func (t *Template) Validate() error {
if t.TemplatePath != "" { if t.TemplatePath != "" {
// Check for file // Check for file
st, err := os.Stat(config.StepAbs(t.TemplatePath)) st, err := os.Stat(step.Abs(t.TemplatePath))
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", t.TemplatePath) return errors.Wrapf(err, "error reading %s", t.TemplatePath)
} }
@ -165,7 +168,7 @@ func (t *Template) Load() error {
if t.Template == nil && t.Type != Directory { if t.Template == nil && t.Type != Directory {
switch { switch {
case t.TemplatePath != "": case t.TemplatePath != "":
filename := config.StepAbs(t.TemplatePath) filename := step.Abs(t.TemplatePath)
b, err := os.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", filename) return errors.Wrapf(err, "error reading %s", filename)
@ -246,7 +249,10 @@ type Output struct {
// Write writes the Output to the filesystem as a directory, file or snippet. // Write writes the Output to the filesystem as a directory, file or snippet.
func (o *Output) Write() error { func (o *Output) Write() error {
path := config.StepAbs(o.Path) // Replace ${STEPPATH} with the base step path.
o.Path = strings.ReplaceAll(o.Path, "${STEPPATH}", step.BasePath())
path := step.Abs(o.Path)
if o.Type == Directory { if o.Type == Directory {
return mkdir(path, 0700) return mkdir(path, 0700)
} }
@ -256,11 +262,17 @@ func (o *Output) Write() error {
return err return err
} }
if o.Type == File { switch o.Type {
case File:
return fileutil.WriteFile(path, o.Content, 0600) return fileutil.WriteFile(path, o.Content, 0600)
case Snippet:
return fileutil.WriteSnippet(path, o.Content, 0600)
case PrependLine:
return fileutil.PrependLine(path, o.Content, 0600)
default:
// Default to using a Snippet type if the type is not known.
return fileutil.WriteSnippet(path, o.Content, 0600)
} }
return fileutil.WriteSnippet(path, o.Content, 0600)
} }
func mkdir(path string, perm os.FileMode) error { func mkdir(path string, perm os.FileMode) error {

View file

@ -4,6 +4,10 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// SSHTemplateVersionKey is a key that can be submitted by a client to select
// the template version that will be returned by the server.
var SSHTemplateVersionKey = "StepSSHTemplateVersion"
// Step represents the default variables available in the CA. // Step represents the default variables available in the CA.
type Step struct { type Step struct {
SSH StepSSH SSH StepSSH
@ -22,16 +26,23 @@ type StepSSH struct {
var DefaultSSHTemplates = SSHTemplates{ var DefaultSSHTemplates = SSHTemplates{
User: []Template{ User: []Template{
{ {
Name: "include.tpl", Name: "config.tpl",
Type: Snippet, Type: Snippet,
TemplatePath: "templates/ssh/include.tpl", TemplatePath: "templates/ssh/config.tpl",
Path: "~/.ssh/config", Path: "~/.ssh/config",
Comment: "#", Comment: "#",
}, },
{ {
Name: "config.tpl", Name: "step_includes.tpl",
Type: PrependLine,
TemplatePath: "templates/ssh/step_includes.tpl",
Path: "${STEPPATH}/ssh/includes",
Comment: "#",
},
{
Name: "step_config.tpl",
Type: File, Type: File,
TemplatePath: "templates/ssh/config.tpl", TemplatePath: "templates/ssh/step_config.tpl",
Path: "ssh/config", Path: "ssh/config",
Comment: "#", Comment: "#",
}, },
@ -64,30 +75,43 @@ var DefaultSSHTemplates = SSHTemplates{
// DefaultSSHTemplateData contains the data of the default templates used on ssh. // DefaultSSHTemplateData contains the data of the default templates used on ssh.
var DefaultSSHTemplateData = map[string]string{ var DefaultSSHTemplateData = map[string]string{
// include.tpl adds the step ssh config file. // base_config.tpl adds the step ssh config file.
// //
// Note: on windows `Include C:\...` is treated as a relative path. // Note: on windows `Include C:\...` is treated as a relative path.
"include.tpl": `Host * "config.tpl": `Host *
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config" {{- if .User.StepBasePath }}
Include "{{ .User.StepBasePath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- else }} {{- else }}
Include "{{.User.StepPath}}/ssh/config" Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- end }}
{{- else }}
{{- if .User.StepBasePath }}
Include "{{.User.StepBasePath}}/ssh/includes"
{{- else }}
Include "{{.User.StepPath}}/ssh/includes"
{{- end }}
{{- end }}`, {{- end }}`,
// includes.tpl adds the step ssh config file.
//
// Note: on windows `Include C:\...` is treated as a relative path.
"step_includes.tpl": `{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}`,
// config.tpl is the step ssh config file, it includes the Match rule and // config.tpl is the step ssh config file, it includes the Match rule and
// references the step known_hosts file. // references the step known_hosts file.
// //
// Note: on windows ProxyCommand requires the full path // Note: on windows ProxyCommand requires the full path
"config.tpl": `Match exec "step ssh check-host %h" "step_config.tpl": `Match exec "step ssh check-host{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %h"
{{- if .User.User }} {{- if .User.User }}
User {{.User.User}} User {{.User.User}}
{{- end }} {{- end }}
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts" UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts"
ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand %r %h %p ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- else }} {{- else }}
UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts" UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts"
ProxyCommand step ssh proxycommand %r %h %p ProxyCommand step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- end }} {{- end }}
`, `,