Merge pull request #692 from smallstep/max/context

Context management
This commit is contained in:
Max 2021-11-17 12:06:42 -08:00 committed by GitHub
commit de2ce5cf9f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 263 additions and 133 deletions

View file

@ -58,6 +58,7 @@ jobs:
run: V=1 make ci run: V=1 make ci
- -
name: Codecov name: Codecov
if: matrix.go == '1.17'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v1.2.1
with: with:
file: ./coverage.out # optional file: ./coverage.out # optional

View file

@ -9,7 +9,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
) )
@ -245,7 +245,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
return "" return ""
} }
stepPath := filepath.ToSlash(config.StepPath()) stepPath := filepath.ToSlash(step.Path())
if !strings.HasSuffix(stepPath, "/") { if !strings.HasSuffix(stepPath, "/") {
stepPath += "/" stepPath += "/"
} }
@ -257,7 +257,7 @@ func mustReadFileOrURI(fn string, m map[string][]byte) string {
panic(err) panic(err)
} }
if ok { if ok {
b, err := os.ReadFile(config.StepAbs(fn)) b, err := os.ReadFile(step.Abs(fn))
if err != nil { if err != nil {
panic(errors.Wrapf(err, "error reading %s", fn)) panic(errors.Wrapf(err, "error reading %s", fn))
} }

View file

@ -13,7 +13,7 @@ import (
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
step "go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/linkedca" "go.step.sm/linkedca"
@ -238,6 +238,8 @@ func (a *Authority) RemoveProvisioner(ctx context.Context, id string) error {
return nil return nil
} }
// CreateFirstProvisioner creates and stores the first provisioner when using
// admin database provisioner storage.
func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) { func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (*linkedca.Provisioner, error) {
if password == "" { if password == "" {
pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one") pass, err := ui.PromptPasswordGenerate("Please enter the password to encrypt your first provisioner, leave empty and we'll generate one")
@ -287,6 +289,7 @@ func CreateFirstProvisioner(ctx context.Context, db admin.DB, password string) (
return p, nil return p, nil
} }
// ValidateClaims validates the Claims type.
func ValidateClaims(c *linkedca.Claims) error { func ValidateClaims(c *linkedca.Claims) error {
if c == nil { if c == nil {
return nil return nil
@ -313,6 +316,7 @@ func ValidateClaims(c *linkedca.Claims) error {
return nil return nil
} }
// ValidateDurations validates the Durations type.
func ValidateDurations(d *linkedca.Durations) error { func ValidateDurations(d *linkedca.Durations) error {
var ( var (
err error err error
@ -523,7 +527,7 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.X509.Template != "" { if p.X509.Template != "" {
x509Template.Template = []byte(p.SSH.Template) x509Template.Template = []byte(p.SSH.Template)
} else if p.X509.TemplateFile != "" { } else if p.X509.TemplateFile != "" {
filename := step.StepAbs(p.X509.TemplateFile) filename := step.Abs(p.X509.TemplateFile)
if x509Template.Template, err = os.ReadFile(filename); err != nil { if x509Template.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading x509 template") return nil, nil, errors.Wrap(err, "error reading x509 template")
} }
@ -539,7 +543,7 @@ func provisionerOptionsToLinkedca(p *provisioner.Options) (*linkedca.Template, *
if p.SSH.Template != "" { if p.SSH.Template != "" {
sshTemplate.Template = []byte(p.SSH.Template) sshTemplate.Template = []byte(p.SSH.Template)
} else if p.SSH.TemplateFile != "" { } else if p.SSH.TemplateFile != "" {
filename := step.StepAbs(p.SSH.TemplateFile) filename := step.Abs(p.SSH.TemplateFile)
if sshTemplate.Template, err = os.ReadFile(filename); err != nil { if sshTemplate.Template, err = os.ReadFile(filename); err != nil {
return nil, nil, errors.Wrap(err, "error reading ssh template") return nil, nil, errors.Wrap(err, "error reading ssh template")
} }

View file

@ -101,6 +101,15 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Backwards compatibility for version of the cli older than v0.18.0.
// Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
// from the cli.
if o.Name == "step_includes.tpl" && data[templates.SSHTemplateVersionKey] == "" {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
}
output = append(output, o) output = append(output, o)
} }
return output, nil return output, nil

View file

@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")}, {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
} }
tmplConfigUserIncludes := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}
userOutputEmptyData := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
}
userOutputWithoutTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
userOutputWithTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
tmplConfigErr := &templates.Templates{ tmplConfigErr := &templates.Templates{
SSH: &templates.SSHTemplates{ SSH: &templates.SSHTemplates{
User: []templates.Template{ User: []templates.Template{
@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},

View file

@ -0,0 +1 @@
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}

View file

@ -28,7 +28,7 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/ca/identity" "github.com/smallstep/certificates/ca/identity"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -237,6 +237,17 @@ func WithTransport(tr http.RoundTripper) ClientOption {
} }
} }
// WithInsecure adds a insecure transport that bypasses TLS verification.
func WithInsecure() ClientOption {
return func(o *clientOptions) error {
o.transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return nil
}
}
// WithRootFile will create the transport using the given root certificate. It // WithRootFile will create the transport using the given root certificate. It
// will fail if a previous option to create the transport has been configured. // will fail if a previous option to create the transport has been configured.
func WithRootFile(filename string) ClientOption { func WithRootFile(filename string) ClientOption {
@ -1294,7 +1305,7 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva
// getRootCAPath returns the path where the root CA is stored based on the // getRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // STEPPATH environment variable.
func getRootCAPath() string { func getRootCAPath() string {
return filepath.Join(config.StepPath(), "certs", "root_ca.crt") return filepath.Join(step.Path(), "certs", "root_ca.crt")
} }
func readJSON(r io.ReadCloser, v interface{}) error { func readJSON(r io.ReadCloser, v interface{}) error {

View file

@ -27,21 +27,22 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json // $STEPPATH/config/identity.json
func LoadClient() (*Client, error) { func LoadClient() (*Client, error) {
b, err := os.ReadFile(DefaultsFile) defaultsFile := DefaultsFile()
b, err := os.ReadFile(defaultsFile)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", DefaultsFile) return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
} }
var defaults defaultsConfig var defaults defaultsConfig
if err := json.Unmarshal(b, &defaults); err != nil { if err := json.Unmarshal(b, &defaults); err != nil {
return nil, errors.Wrapf(err, "error unmarshaling %s", DefaultsFile) return nil, errors.Wrapf(err, "error unmarshaling %s", defaultsFile)
} }
if err := defaults.Validate(); err != nil { if err := defaults.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
caURL, err := url.Parse(defaults.CaURL) caURL, err := url.Parse(defaults.CaURL)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error validating %s", DefaultsFile) return nil, errors.Wrapf(err, "error validating %s", defaultsFile)
} }
if caURL.Scheme == "" { if caURL.Scheme == "" {
caURL.Scheme = "https" caURL.Scheme = "https"
@ -52,7 +53,7 @@ func LoadClient() (*Client, error) {
return nil, err return nil, err
} }
if err := identity.Validate(); err != nil { if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", IdentityFile) return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
} }
if kind := identity.Kind(); kind != MutualTLS { if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)

View file

@ -11,6 +11,12 @@ import (
"testing" "testing"
) )
func returnInput(val string) func() string {
return func() string {
return val
}
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
oldIdentityFile := IdentityFile oldIdentityFile := IdentityFile
oldDefaultsFile := DefaultsFile oldDefaultsFile := DefaultsFile
@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
DefaultsFile = oldDefaultsFile DefaultsFile = oldDefaultsFile
}() }()
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
client, err := LoadClient() client, err := LoadClient()
if err != nil { if err != nil {
@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", func() { {"ok", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, expected, false}, }, expected, false},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/missing.json" IdentityFile = returnInput("testdata/config/missing.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/fail.json" IdentityFile = returnInput("testdata/config/fail.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/missing.json" DefaultsFile = returnInput("testdata/config/missing.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/fail.json" DefaultsFile = returnInput("testdata/config/fail.json")
}, nil, true}, }, nil, true},
{"fail ca", func() { {"fail ca", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badca.json" DefaultsFile = returnInput("testdata/config/badca.json")
}, nil, true}, }, nil, true},
{"fail root", func() { {"fail root", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badroot.json" DefaultsFile = returnInput("testdata/config/badroot.json")
}, nil, true}, }, nil, true},
{"fail type", func() { {"fail type", func() {
IdentityFile = "testdata/config/badIdentity.json" IdentityFile = returnInput("testdata/config/badIdentity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -15,7 +15,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
@ -38,11 +38,18 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims. // DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute const DefaultLeeway = 1 * time.Minute
// IdentityFile contains the location of the identity file. var (
var IdentityFile = filepath.Join(config.StepPath(), "config", "identity.json") identityDir = step.IdentityPath
configDir = step.ConfigPath
// DefaultsFile contains the location of the defaults file. // IdentityFile contains a pointer to a function that outputs the location of
var DefaultsFile = filepath.Join(config.StepPath(), "config", "defaults.json") // the identity file.
IdentityFile = step.IdentityFile
// DefaultsFile contains a prointer a function that outputs the location of the
// defaults configuration file.
DefaultsFile = step.DefaultsFile
)
// Identity represents the identity file that can be used to authenticate with // Identity represents the identity file that can be used to authenticate with
// the CA. // the CA.
@ -73,23 +80,17 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity. // LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) { func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(IdentityFile) return LoadIdentity(IdentityFile())
} }
// configDir and identityDir are used in WriteDefaultIdentity for testing
// purposes.
var (
configDir = filepath.Join(config.StepPath(), "config")
identityDir = filepath.Join(config.StepPath(), "identity")
)
// WriteDefaultIdentity writes the given certificates and key and the // WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files. // identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(configDir, 0700); err != nil { if err := os.MkdirAll(configDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory") return errors.Wrap(err, "error creating config directory")
} }
identityDir := identityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil { if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory") return errors.Wrap(err, "error creating identity directory")
} }
@ -126,7 +127,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil { }); err != nil {
return errors.Wrap(err, "error writing identity json") return errors.Wrap(err, "error writing identity json")
} }
if err := os.WriteFile(IdentityFile, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -135,7 +136,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk. // WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error { func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(identityDir, "identity.crt") filename := filepath.Join(identityDir(), "identity.crt")
return writeCertificate(filename, certChain) return writeCertificate(filename, certChain)
} }
@ -318,7 +319,7 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate") return errors.Wrap(err, "error encoding identity certificate")
} }
} }
certFilename := filepath.Join(identityDir, "identity.crt") certFilename := filepath.Join(identityDir(), "identity.crt")
if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { if err := os.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }

View file

@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) {
want *Identity want *Identity
wantErr bool wantErr bool
}{ }{
{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false}, {"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false},
{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true}, {"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true},
{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true}, {"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) {
certChain = append(certChain, api.Certificate{Certificate: c}) certChain = append(certChain, api.Certificate{Certificate: c})
} }
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(tmpDir, "config", "identity.json") IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json"))
type args struct { type args struct {
certChain []api.Certificate certChain []api.Certificate
@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) {
}{ }{
{"ok", func() {}, args{certChain, key}, false}, {"ok", func() {}, args{certChain, key}, false},
{"fail mkdir config", func() { {"fail mkdir config", func() {
configDir = filepath.Join(tmpDir, "identity", "identity.crt") configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail mkdir identity", func() { {"fail mkdir identity", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity", "identity.crt") identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail certificate", func() { {"fail certificate", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail key", func() { {"fail key", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, "badKey"}, true}, }, args{certChain, "badKey"}, true},
{"fail write identity", func() { {"fail write identity", func() {
configDir = filepath.Join(tmpDir, "bad-dir") configDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(configDir, "identity.json") IdentityFile = returnInput(filepath.Join(configDir(), "identity.json"))
os.MkdirAll(configDir, 0600) os.MkdirAll(configDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
} }
@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) {
} }
oldIdentityDir := identityDir oldIdentityDir := identityDir
identityDir = "testdata/identity" identityDir = returnInput("testdata/identity")
defer func() { defer func() {
identityDir = oldIdentityDir identityDir = oldIdentityDir
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) {
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail write identity", func() { {"fail write identity", func() {
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -21,7 +21,7 @@ import (
"github.com/urfave/cli" "github.com/urfave/cli"
"go.step.sm/cli-utils/command" "go.step.sm/cli-utils/command"
"go.step.sm/cli-utils/command/version" "go.step.sm/cli-utils/command/version"
"go.step.sm/cli-utils/config" "go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/cli-utils/usage" "go.step.sm/cli-utils/usage"
@ -49,7 +49,7 @@ var (
) )
func init() { func init() {
config.Set("Smallstep CA", Version, BuildTime) step.Set("Smallstep CA", Version, BuildTime)
authority.GlobalVersion.Version = Version authority.GlobalVersion.Version = Version
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
} }
@ -115,7 +115,7 @@ func main() {
app := cli.NewApp() app := cli.NewApp()
app.Name = "step-ca" app.Name = "step-ca"
app.HelpName = "step-ca" app.HelpName = "step-ca"
app.Version = config.Version() app.Version = step.Version()
app.Usage = "an online certificate authority for secure automated certificate management" app.Usage = "an online certificate authority for secure automated certificate management"
app.UsageText = `**step-ca** <config> [**--password-file**=<file>] app.UsageText = `**step-ca** <config> [**--password-file**=<file>]
[**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>] [**--ssh-host-password-file**=<file>] [**--ssh-user-password-file**=<file>]

13
go.mod
View file

@ -29,10 +29,10 @@ require (
github.com/rs/xid v1.2.1 github.com/rs/xid v1.2.1
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262
github.com/smallstep/nosql v0.3.8 github.com/smallstep/nosql v0.3.9
github.com/urfave/cli v1.22.4 github.com/urfave/cli v1.22.4
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352 go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
go.step.sm/cli-utils v0.6.2 go.step.sm/cli-utils v0.7.0
go.step.sm/crypto v0.13.0 go.step.sm/crypto v0.13.0
go.step.sm/linkedca v0.7.0 go.step.sm/linkedca v0.7.0
golang.org/x/crypto v0.0.0-20210915214749-c084706c2272 golang.org/x/crypto v0.0.0-20210915214749-c084706c2272
@ -44,7 +44,8 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/square/go-jose.v2 v2.6.0
) )
// replace github.com/smallstep/nosql => ../nosql //replace github.com/smallstep/nosql => ../nosql
// replace go.step.sm/crypto => ../crypto
// replace go.step.sm/cli-utils => ../cli-utils //replace go.step.sm/crypto => ../crypto
// replace go.step.sm/linkedca => ../linkedca
//replace go.step.sm/cli-utils => ../cli-utils

8
go.sum
View file

@ -494,8 +494,8 @@ github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6Mwd
github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE= github.com/smallstep/assert v0.0.0-20180720014142-de77670473b5/go.mod h1:TC9A4+RjIOS+HyTH7wG17/gSqVv95uDw2J64dQZx7RE=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262 h1:unQFBIznI+VYD1/1fApl1A+9VcBk+9dcqGfnePY87LY=
github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc= github.com/smallstep/assert v0.0.0-20200723003110-82e2b9b3b262/go.mod h1:MyOHs9Po2fbM1LHej6sBUT8ozbxmMOFG+E+rx/GSGuc=
github.com/smallstep/nosql v0.3.8 h1:1/EWUbbEdz9ai0g9Fd09VekVjtxp+5+gIHpV2PdwW3o= github.com/smallstep/nosql v0.3.9 h1:YPy5PR3PXClqmpFaVv0wfXDXDc7NXGBE1auyU2c87dc=
github.com/smallstep/nosql v0.3.8/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs= github.com/smallstep/nosql v0.3.9/go.mod h1:X2qkYpNcW3yjLUvhEHfgGfClpKbFPapewvx7zo4TOFs=
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM= github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
@ -559,8 +559,8 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M=
go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E=
go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI=
go.step.sm/cli-utils v0.6.2 h1:ofa3G/EqE3dTDXmzoXHDZr18qJZoFsKSzbzuF+mxuZU= go.step.sm/cli-utils v0.7.0 h1:2GvY5Muid1yzp7YQbfCCS+gK3q7zlHjjLL5Z0DXz8ds=
go.step.sm/cli-utils v0.6.2/go.mod h1:0tZ8F2QwLgD6KbKj4nrQZhMakTasEAnOcW3Ekc5pnrA= go.step.sm/cli-utils v0.7.0/go.mod h1:Ur6bqA/yl636kCUJbp30J7Unv5JJ226eW2KqXPDwF/E=
go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0= go.step.sm/crypto v0.9.0/go.mod h1:+CYG05Mek1YDqi5WK0ERc6cOpKly2i/a5aZmU1sfGj0=
go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8= go.step.sm/crypto v0.13.0 h1:mQuP9Uu2FNmqCJNO0OTbvolnYXzONy4wdUBtUVcP1s8=
go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg= go.step.sm/crypto v0.13.0/go.mod h1:5YzQ85BujYBu6NH18jw7nFjwuRnDch35nLzH0ES5sKg=

View file

@ -29,9 +29,9 @@ import (
"github.com/smallstep/certificates/kms" "github.com/smallstep/certificates/kms"
kmsapi "github.com/smallstep/certificates/kms/apiv1" kmsapi "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
"go.step.sm/cli-utils/ui" "go.step.sm/cli-utils/ui"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
@ -87,44 +87,50 @@ const (
) )
// GetDBPath returns the path where the file-system persistence is stored // GetDBPath returns the path where the file-system persistence is stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetDBPath() string { func GetDBPath() string {
return filepath.Join(config.StepPath(), dbPath) return filepath.Join(step.Path(), dbPath)
} }
// GetConfigPath returns the directory where the configuration files are stored // GetConfigPath returns the directory where the configuration files are stored
// based on the STEPPATH environment variable. // based on the $(step path).
func GetConfigPath() string { func GetConfigPath() string {
return filepath.Join(config.StepPath(), configPath) return filepath.Join(step.Path(), configPath)
}
// GetProfileConfigPath returns the directory where the profile configuration
// files are stored based on the $(step path).
func GetProfileConfigPath() string {
return filepath.Join(step.ProfilePath(), configPath)
} }
// GetPublicPath returns the directory where the public keys are stored based on // GetPublicPath returns the directory where the public keys are stored based on
// the STEPPATH environment variable. // the $(step path).
func GetPublicPath() string { func GetPublicPath() string {
return filepath.Join(config.StepPath(), publicPath) return filepath.Join(step.Path(), publicPath)
} }
// GetSecretsPath returns the directory where the private keys are stored based // GetSecretsPath returns the directory where the private keys are stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetSecretsPath() string { func GetSecretsPath() string {
return filepath.Join(config.StepPath(), privatePath) return filepath.Join(step.Path(), privatePath)
} }
// GetRootCAPath returns the path where the root CA is stored based on the // GetRootCAPath returns the path where the root CA is stored based on the
// STEPPATH environment variable. // $(step path).
func GetRootCAPath() string { func GetRootCAPath() string {
return filepath.Join(config.StepPath(), publicPath, "root_ca.crt") return filepath.Join(step.Path(), publicPath, "root_ca.crt")
} }
// GetOTTKeyPath returns the path where the one-time token key is stored based // GetOTTKeyPath returns the path where the one-time token key is stored based
// on the STEPPATH environment variable. // on the $(step path).
func GetOTTKeyPath() string { func GetOTTKeyPath() string {
return filepath.Join(config.StepPath(), privatePath, "ott_key") return filepath.Join(step.Path(), privatePath, "ott_key")
} }
// GetTemplatesPath returns the path where the templates are stored. // GetTemplatesPath returns the path where the templates are stored.
func GetTemplatesPath() string { func GetTemplatesPath() string {
return filepath.Join(config.StepPath(), templatesPath) return filepath.Join(step.Path(), templatesPath)
} }
// GetProvisioners returns the map of provisioners on the given CA. // GetProvisioners returns the map of provisioners on the given CA.
@ -293,6 +299,7 @@ type PKI struct {
keyManager kmsapi.KeyManager keyManager kmsapi.KeyManager
config string config string
defaults string defaults string
profileDefaults string
ottPublicKey *jose.JSONWebKey ottPublicKey *jose.JSONWebKey
ottPrivateKey *jose.JSONWebEncryption ottPrivateKey *jose.JSONWebEncryption
options *options options *options
@ -300,6 +307,7 @@ type PKI struct {
// New creates a new PKI configuration. // New creates a new PKI configuration.
func New(o apiv1.Options, opts ...Option) (*PKI, error) { func New(o apiv1.Options, opts ...Option) (*PKI, error) {
currentCtx := step.Contexts().GetCurrent()
caService, err := cas.New(context.Background(), o) caService, err := cas.New(context.Background(), o)
if err != nil { if err != nil {
return nil, err return nil, err
@ -358,6 +366,9 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
cfg = GetConfigPath() cfg = GetConfigPath()
// Create directories // Create directories
dirs := []string{public, private, cfg, GetTemplatesPath()} dirs := []string{public, private, cfg, GetTemplatesPath()}
if currentCtx != nil {
dirs = append(dirs, GetProfileConfigPath())
}
for _, name := range dirs { for _, name := range dirs {
if _, err := os.Stat(name); os.IsNotExist(err) { if _, err := os.Stat(name); os.IsNotExist(err) {
if err = os.MkdirAll(name, 0700); err != nil { if err = os.MkdirAll(name, 0700); err != nil {
@ -415,6 +426,10 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) {
if p.defaults, err = getPath(cfg, "defaults.json"); err != nil { if p.defaults, err = getPath(cfg, "defaults.json"); err != nil {
return nil, err return nil, err
} }
if currentCtx != nil {
p.profileDefaults = currentCtx.ProfileDefaultsFile()
}
if p.config, err = getPath(cfg, "ca.json"); err != nil { if p.config, err = getPath(cfg, "ca.json"); err != nil {
return nil, err return nil, err
} }
@ -944,6 +959,18 @@ func (p *PKI) Save(opt ...ConfigOption) error {
if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil { if err = fileutil.WriteFile(p.defaults, b, 0644); err != nil {
return errs.FileError(err, p.defaults) return errs.FileError(err, p.defaults)
} }
// If we're using contexts then write a blank object to the default profile
// configuration location.
if p.profileDefaults != "" {
if _, err := os.Stat(p.profileDefaults); os.IsNotExist(err) {
// Write with 0600 to be consistent with directories structure.
if err = fileutil.WriteFile(p.profileDefaults, []byte("{}"), 0600); err != nil {
return errs.FileError(err, p.profileDefaults)
}
} else if err != nil {
return errs.FileError(err, p.profileDefaults)
}
}
// Generate and write templates // Generate and write templates
if err := generateTemplates(cfg.Templates); err != nil { if err := generateTemplates(cfg.Templates); err != nil {
@ -958,6 +985,9 @@ func (p *PKI) Save(opt ...ConfigOption) error {
} }
ui.PrintSelected("Default configuration", p.defaults) ui.PrintSelected("Default configuration", p.defaults)
if p.profileDefaults != "" {
ui.PrintSelected("Default profile configuration", p.profileDefaults)
}
ui.PrintSelected("Certificate Authority configuration", p.config) ui.PrintSelected("Certificate Authority configuration", p.config)
if p.options.deploymentType != LinkedDeployment { if p.options.deploymentType != LinkedDeployment {
ui.Println() ui.Println()

View file

@ -6,9 +6,9 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/errs" "go.step.sm/cli-utils/errs"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// getTemplates returns all the templates enabled // getTemplates returns all the templates enabled
@ -44,7 +44,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }
@ -53,7 +53,7 @@ func generateTemplates(t *templates.Templates) error {
if !ok { if !ok {
return errors.Errorf("template %s does not exists", t.Name) return errors.Errorf("template %s does not exists", t.Name)
} }
if err := fileutil.WriteFile(config.StepAbs(t.TemplatePath), []byte(data), 0644); err != nil { if err := fileutil.WriteFile(step.Abs(t.TemplatePath), []byte(data), 0644); err != nil {
return err return err
} }
} }

View file

@ -9,8 +9,8 @@ import (
"github.com/Masterminds/sprig/v3" "github.com/Masterminds/sprig/v3"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/cli-utils/config"
"go.step.sm/cli-utils/fileutil" "go.step.sm/cli-utils/fileutil"
"go.step.sm/cli-utils/step"
) )
// TemplateType defines how a template will be written in disk. // TemplateType defines how a template will be written in disk.
@ -19,6 +19,9 @@ type TemplateType string
const ( const (
// Snippet will mark a template as a part of a file. // Snippet will mark a template as a part of a file.
Snippet TemplateType = "snippet" Snippet TemplateType = "snippet"
// PrependLine is a template for prepending a single line to a file. If the
// line already exists in the file it will be removed first.
PrependLine TemplateType = "prepend-line"
// File will mark a templates as a full file. // File will mark a templates as a full file.
File TemplateType = "file" File TemplateType = "file"
// Directory will mark a template as a directory. // Directory will mark a template as a directory.
@ -98,7 +101,7 @@ func (t *SSHTemplates) Validate() (err error) {
return return
} }
// Template represents on template file. // Template represents a template file.
type Template struct { type Template struct {
*template.Template *template.Template
Name string `json:"name"` Name string `json:"name"`
@ -117,8 +120,8 @@ func (t *Template) Validate() error {
return nil return nil
case t.Name == "": case t.Name == "":
return errors.New("template name cannot be empty") return errors.New("template name cannot be empty")
case t.Type != Snippet && t.Type != File && t.Type != Directory: case t.Type != Snippet && t.Type != File && t.Type != Directory && t.Type != PrependLine:
return errors.Errorf("invalid template type %s, it must be %s, %s, or %s", t.Type, Snippet, File, Directory) return errors.Errorf("invalid template type %s, it must be %s, %s, %s, or %s", t.Type, Snippet, PrependLine, File, Directory)
case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0: case t.TemplatePath == "" && t.Type != Directory && len(t.Content) == 0:
return errors.New("template template cannot be empty") return errors.New("template template cannot be empty")
case t.TemplatePath != "" && t.Type == Directory: case t.TemplatePath != "" && t.Type == Directory:
@ -131,7 +134,7 @@ func (t *Template) Validate() error {
if t.TemplatePath != "" { if t.TemplatePath != "" {
// Check for file // Check for file
st, err := os.Stat(config.StepAbs(t.TemplatePath)) st, err := os.Stat(step.Abs(t.TemplatePath))
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", t.TemplatePath) return errors.Wrapf(err, "error reading %s", t.TemplatePath)
} }
@ -165,7 +168,7 @@ func (t *Template) Load() error {
if t.Template == nil && t.Type != Directory { if t.Template == nil && t.Type != Directory {
switch { switch {
case t.TemplatePath != "": case t.TemplatePath != "":
filename := config.StepAbs(t.TemplatePath) filename := step.Abs(t.TemplatePath)
b, err := os.ReadFile(filename) b, err := os.ReadFile(filename)
if err != nil { if err != nil {
return errors.Wrapf(err, "error reading %s", filename) return errors.Wrapf(err, "error reading %s", filename)
@ -246,7 +249,10 @@ type Output struct {
// Write writes the Output to the filesystem as a directory, file or snippet. // Write writes the Output to the filesystem as a directory, file or snippet.
func (o *Output) Write() error { func (o *Output) Write() error {
path := config.StepAbs(o.Path) // Replace ${STEPPATH} with the base step path.
o.Path = strings.ReplaceAll(o.Path, "${STEPPATH}", step.BasePath())
path := step.Abs(o.Path)
if o.Type == Directory { if o.Type == Directory {
return mkdir(path, 0700) return mkdir(path, 0700)
} }
@ -256,11 +262,17 @@ func (o *Output) Write() error {
return err return err
} }
if o.Type == File { switch o.Type {
case File:
return fileutil.WriteFile(path, o.Content, 0600) return fileutil.WriteFile(path, o.Content, 0600)
} case Snippet:
return fileutil.WriteSnippet(path, o.Content, 0600) return fileutil.WriteSnippet(path, o.Content, 0600)
case PrependLine:
return fileutil.PrependLine(path, o.Content, 0600)
default:
// Default to using a Snippet type if the type is not known.
return fileutil.WriteSnippet(path, o.Content, 0600)
}
} }
func mkdir(path string, perm os.FileMode) error { func mkdir(path string, perm os.FileMode) error {

View file

@ -4,6 +4,10 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// SSHTemplateVersionKey is a key that can be submitted by a client to select
// the template version that will be returned by the server.
var SSHTemplateVersionKey = "StepSSHTemplateVersion"
// Step represents the default variables available in the CA. // Step represents the default variables available in the CA.
type Step struct { type Step struct {
SSH StepSSH SSH StepSSH
@ -22,16 +26,23 @@ type StepSSH struct {
var DefaultSSHTemplates = SSHTemplates{ var DefaultSSHTemplates = SSHTemplates{
User: []Template{ User: []Template{
{ {
Name: "include.tpl", Name: "config.tpl",
Type: Snippet, Type: Snippet,
TemplatePath: "templates/ssh/include.tpl", TemplatePath: "templates/ssh/config.tpl",
Path: "~/.ssh/config", Path: "~/.ssh/config",
Comment: "#", Comment: "#",
}, },
{ {
Name: "config.tpl", Name: "step_includes.tpl",
Type: PrependLine,
TemplatePath: "templates/ssh/step_includes.tpl",
Path: "${STEPPATH}/ssh/includes",
Comment: "#",
},
{
Name: "step_config.tpl",
Type: File, Type: File,
TemplatePath: "templates/ssh/config.tpl", TemplatePath: "templates/ssh/step_config.tpl",
Path: "ssh/config", Path: "ssh/config",
Comment: "#", Comment: "#",
}, },
@ -64,30 +75,43 @@ var DefaultSSHTemplates = SSHTemplates{
// DefaultSSHTemplateData contains the data of the default templates used on ssh. // DefaultSSHTemplateData contains the data of the default templates used on ssh.
var DefaultSSHTemplateData = map[string]string{ var DefaultSSHTemplateData = map[string]string{
// include.tpl adds the step ssh config file. // base_config.tpl adds the step ssh config file.
// //
// Note: on windows `Include C:\...` is treated as a relative path. // Note: on windows `Include C:\...` is treated as a relative path.
"include.tpl": `Host * "config.tpl": `Host *
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config" {{- if .User.StepBasePath }}
Include "{{ .User.StepBasePath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- else }} {{- else }}
Include "{{.User.StepPath}}/ssh/config" Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/includes"
{{- end }}
{{- else }}
{{- if .User.StepBasePath }}
Include "{{.User.StepBasePath}}/ssh/includes"
{{- else }}
Include "{{.User.StepPath}}/ssh/includes"
{{- end }}
{{- end }}`, {{- end }}`,
// includes.tpl adds the step ssh config file.
//
// Note: on windows `Include C:\...` is treated as a relative path.
"step_includes.tpl": `{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}`,
// config.tpl is the step ssh config file, it includes the Match rule and // config.tpl is the step ssh config file, it includes the Match rule and
// references the step known_hosts file. // references the step known_hosts file.
// //
// Note: on windows ProxyCommand requires the full path // Note: on windows ProxyCommand requires the full path
"config.tpl": `Match exec "step ssh check-host %h" "step_config.tpl": `Match exec "step ssh check-host{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %h"
{{- if .User.User }} {{- if .User.User }}
User {{.User.User}} User {{.User.User}}
{{- end }} {{- end }}
{{- if or .User.GOOS "none" | eq "windows" }} {{- if or .User.GOOS "none" | eq "windows" }}
UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts" UserKnownHostsFile "{{.User.StepPath}}\ssh\known_hosts"
ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand %r %h %p ProxyCommand C:\Windows\System32\cmd.exe /c step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- else }} {{- else }}
UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts" UserKnownHostsFile "{{.User.StepPath}}/ssh/known_hosts"
ProxyCommand step ssh proxycommand %r %h %p ProxyCommand step ssh proxycommand{{- if .User.Context }} --context {{ .User.Context }}{{- end }} %r %h %p
{{- end }} {{- end }}
`, `,