SSH backwards compat updates

- use existence of new value in data map as boolean
- add tests for backwards and forwards compatibility
- fix old tests that used static dir locations
This commit is contained in:
max furman 2021-11-15 15:32:07 -08:00
parent d37313bef4
commit a7d144996f
7 changed files with 105 additions and 61 deletions

View file

@ -102,10 +102,14 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
return nil, err
}
// Backwards compatibility for version of the cli older than v0.18.0
if o.Name == "step_includes.tpl" && (data == nil || data[templates.SSHTemplateVersionKey] != "v2") {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
// Backwards compatibility for version of the cli older than v0.18.0.
// Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
// from the cli.
if o.Name == "step_includes.tpl" {
if val, ok := data[templates.SSHTemplateVersionKey]; !ok || val == "" {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
}
}
output = append(output, o)

View file

@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
}
tmplConfigUserIncludes := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}
userOutputEmptyData := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
}
userOutputWithoutTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
userOutputWithTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
tmplConfigErr := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},

View file

@ -0,0 +1 @@
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}

View file

@ -10,7 +10,6 @@ import (
"net/url"
"github.com/pkg/errors"
"go.step.sm/cli-utils/step"
)
// Client wraps http.Client with a transport using the step root and identity.
@ -28,7 +27,7 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json
func LoadClient() (*Client, error) {
defaultsFile := step.DefaultsFile()
defaultsFile := DefaultsFile()
b, err := ioutil.ReadFile(defaultsFile)
if err != nil {
return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
@ -54,7 +53,7 @@ func LoadClient() (*Client, error) {
return nil, err
}
if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", step.IdentityFile())
return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
}
if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)

View file

@ -11,6 +11,12 @@ import (
"testing"
)
func returnInput(val string) func() string {
return func() string {
return val
}
}
func TestClient(t *testing.T) {
oldIdentityFile := IdentityFile
oldDefaultsFile := DefaultsFile
@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
DefaultsFile = oldDefaultsFile
}()
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
client, err := LoadClient()
if err != nil {
@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
wantErr bool
}{
{"ok", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
}, expected, false},
{"fail identity", func() {
IdentityFile = "testdata/config/missing.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/missing.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true},
{"fail identity", func() {
IdentityFile = "testdata/config/fail.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/fail.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true},
{"fail defaults", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/missing.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/missing.json")
}, nil, true},
{"fail defaults", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/fail.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/fail.json")
}, nil, true},
{"fail ca", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/badca.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/badca.json")
}, nil, true},
{"fail root", func() {
IdentityFile = "testdata/config/identity.json"
DefaultsFile = "testdata/config/badroot.json"
IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = returnInput("testdata/config/badroot.json")
}, nil, true},
{"fail type", func() {
IdentityFile = "testdata/config/badIdentity.json"
DefaultsFile = "testdata/config/defaults.json"
IdentityFile = returnInput("testdata/config/badIdentity.json")
DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true},
}
for _, tt := range tests {

View file

@ -39,6 +39,19 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute
var (
identityDir = step.IdentityPath
configDir = step.ConfigPath
// IdentityFile contains a pointer to a function that outputs the location of
// the identity file.
IdentityFile = step.IdentityFile
// DefaultsFile contains a prointer a function that outputs the location of the
// defaults configuration file.
DefaultsFile = step.DefaultsFile
)
// Identity represents the identity file that can be used to authenticate with
// the CA.
type Identity struct {
@ -68,25 +81,17 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(step.IdentityFile())
}
func profileConfigDir() string {
return filepath.Join(step.Path(), "config")
}
func profileIdentityDir() string {
return filepath.Join(step.Path(), "identity")
return LoadIdentity(IdentityFile())
}
// WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(profileConfigDir(), 0700); err != nil {
if err := os.MkdirAll(configDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory")
}
identityDir := profileIdentityDir()
identityDir := identityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory")
}
@ -123,7 +128,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil {
return errors.Wrap(err, "error writing identity json")
}
if err := ioutil.WriteFile(step.IdentityFile(), buf.Bytes(), 0600); err != nil {
if err := ioutil.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}
@ -132,7 +137,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(profileIdentityDir(), "identity.crt")
filename := filepath.Join(identityDir(), "identity.crt")
return writeCertificate(filename, certChain)
}
@ -315,7 +320,7 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate")
}
}
certFilename := filepath.Join(profileIdentityDir(), "identity.crt")
certFilename := filepath.Join(identityDir(), "identity.crt")
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate")
}

View file

@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) {
want *Identity
wantErr bool
}{
{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false},
{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true},
{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true},
{"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false},
{"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true},
{"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) {
certChain = append(certChain, api.Certificate{Certificate: c})
}
configDir = filepath.Join(tmpDir, "config")
identityDir = filepath.Join(tmpDir, "identity")
IdentityFile = filepath.Join(tmpDir, "config", "identity.json")
configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json"))
type args struct {
certChain []api.Certificate
@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) {
}{
{"ok", func() {}, args{certChain, key}, false},
{"fail mkdir config", func() {
configDir = filepath.Join(tmpDir, "identity", "identity.crt")
identityDir = filepath.Join(tmpDir, "identity")
configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, key}, true},
{"fail mkdir identity", func() {
configDir = filepath.Join(tmpDir, "config")
identityDir = filepath.Join(tmpDir, "identity", "identity.crt")
configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
}, args{certChain, key}, true},
{"fail certificate", func() {
configDir = filepath.Join(tmpDir, "config")
identityDir = filepath.Join(tmpDir, "bad-dir")
os.MkdirAll(identityDir, 0600)
configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir(), 0600)
}, args{certChain, key}, true},
{"fail key", func() {
configDir = filepath.Join(tmpDir, "config")
identityDir = filepath.Join(tmpDir, "identity")
configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, "badKey"}, true},
{"fail write identity", func() {
configDir = filepath.Join(tmpDir, "bad-dir")
identityDir = filepath.Join(tmpDir, "identity")
IdentityFile = filepath.Join(configDir, "identity.json")
os.MkdirAll(configDir, 0600)
configDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = returnInput(filepath.Join(configDir(), "identity.json"))
os.MkdirAll(configDir(), 0600)
}, args{certChain, key}, true},
}
@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) {
}
oldIdentityDir := identityDir
identityDir = "testdata/identity"
identityDir = returnInput("testdata/identity")
defer func() {
identityDir = oldIdentityDir
os.RemoveAll(tmpDir)
@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) {
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail write identity", func() {
identityDir = filepath.Join(tmpDir, "bad-dir")
os.MkdirAll(identityDir, 0600)
identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir(), 0600)
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
}
for _, tt := range tests {