SSH backwards compat updates

- use existence of new value in data map as boolean
- add tests for backwards and forwards compatibility
- fix old tests that used static dir locations
This commit is contained in:
max furman 2021-11-15 15:32:07 -08:00
parent d37313bef4
commit a7d144996f
7 changed files with 105 additions and 61 deletions

View file

@ -102,10 +102,14 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
return nil, err return nil, err
} }
// Backwards compatibility for version of the cli older than v0.18.0 // Backwards compatibility for version of the cli older than v0.18.0.
if o.Name == "step_includes.tpl" && (data == nil || data[templates.SSHTemplateVersionKey] != "v2") { // Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
o.Type = templates.File // from the cli.
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/") if o.Name == "step_includes.tpl" {
if val, ok := data[templates.SSHTemplateVersionKey]; !ok || val == "" {
o.Type = templates.File
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
}
} }
output = append(output, o) output = append(output, o)

View file

@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")}, {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
} }
tmplConfigUserIncludes := &templates.Templates{
SSH: &templates.SSHTemplates{
User: []templates.Template{
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
},
},
Data: map[string]interface{}{
"Step": &templates.Step{
SSH: templates.StepSSH{
UserKey: user,
HostKey: host,
},
},
},
}
userOutputEmptyData := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
}
userOutputWithoutTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
userOutputWithTemplateVersion := []templates.Output{
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
}
tmplConfigErr := &templates.Templates{ tmplConfigErr := &templates.Templates{
SSH: &templates.SSHTemplates{ SSH: &templates.SSHTemplates{
User: []templates.Template{ User: []templates.Template{
@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},

View file

@ -0,0 +1 @@
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}

View file

@ -10,7 +10,6 @@ import (
"net/url" "net/url"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/cli-utils/step"
) )
// Client wraps http.Client with a transport using the step root and identity. // Client wraps http.Client with a transport using the step root and identity.
@ -28,7 +27,7 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
// $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/defaults.json and the identity defined in
// $STEPPATH/config/identity.json // $STEPPATH/config/identity.json
func LoadClient() (*Client, error) { func LoadClient() (*Client, error) {
defaultsFile := step.DefaultsFile() defaultsFile := DefaultsFile()
b, err := ioutil.ReadFile(defaultsFile) b, err := ioutil.ReadFile(defaultsFile)
if err != nil { if err != nil {
return nil, errors.Wrapf(err, "error reading %s", defaultsFile) return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
@ -54,7 +53,7 @@ func LoadClient() (*Client, error) {
return nil, err return nil, err
} }
if err := identity.Validate(); err != nil { if err := identity.Validate(); err != nil {
return nil, errors.Wrapf(err, "error validating %s", step.IdentityFile()) return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
} }
if kind := identity.Kind(); kind != MutualTLS { if kind := identity.Kind(); kind != MutualTLS {
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)

View file

@ -11,6 +11,12 @@ import (
"testing" "testing"
) )
func returnInput(val string) func() string {
return func() string {
return val
}
}
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
oldIdentityFile := IdentityFile oldIdentityFile := IdentityFile
oldDefaultsFile := DefaultsFile oldDefaultsFile := DefaultsFile
@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
DefaultsFile = oldDefaultsFile DefaultsFile = oldDefaultsFile
}() }()
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
client, err := LoadClient() client, err := LoadClient()
if err != nil { if err != nil {
@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
wantErr bool wantErr bool
}{ }{
{"ok", func() { {"ok", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, expected, false}, }, expected, false},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/missing.json" IdentityFile = returnInput("testdata/config/missing.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail identity", func() { {"fail identity", func() {
IdentityFile = "testdata/config/fail.json" IdentityFile = returnInput("testdata/config/fail.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/missing.json" DefaultsFile = returnInput("testdata/config/missing.json")
}, nil, true}, }, nil, true},
{"fail defaults", func() { {"fail defaults", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/fail.json" DefaultsFile = returnInput("testdata/config/fail.json")
}, nil, true}, }, nil, true},
{"fail ca", func() { {"fail ca", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badca.json" DefaultsFile = returnInput("testdata/config/badca.json")
}, nil, true}, }, nil, true},
{"fail root", func() { {"fail root", func() {
IdentityFile = "testdata/config/identity.json" IdentityFile = returnInput("testdata/config/identity.json")
DefaultsFile = "testdata/config/badroot.json" DefaultsFile = returnInput("testdata/config/badroot.json")
}, nil, true}, }, nil, true},
{"fail type", func() { {"fail type", func() {
IdentityFile = "testdata/config/badIdentity.json" IdentityFile = returnInput("testdata/config/badIdentity.json")
DefaultsFile = "testdata/config/defaults.json" DefaultsFile = returnInput("testdata/config/defaults.json")
}, nil, true}, }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {

View file

@ -39,6 +39,19 @@ const TunnelTLS Type = "tTLS"
// DefaultLeeway is the duration for matching not before claims. // DefaultLeeway is the duration for matching not before claims.
const DefaultLeeway = 1 * time.Minute const DefaultLeeway = 1 * time.Minute
var (
identityDir = step.IdentityPath
configDir = step.ConfigPath
// IdentityFile contains a pointer to a function that outputs the location of
// the identity file.
IdentityFile = step.IdentityFile
// DefaultsFile contains a prointer a function that outputs the location of the
// defaults configuration file.
DefaultsFile = step.DefaultsFile
)
// Identity represents the identity file that can be used to authenticate with // Identity represents the identity file that can be used to authenticate with
// the CA. // the CA.
type Identity struct { type Identity struct {
@ -68,25 +81,17 @@ func LoadIdentity(filename string) (*Identity, error) {
// LoadDefaultIdentity loads the default identity. // LoadDefaultIdentity loads the default identity.
func LoadDefaultIdentity() (*Identity, error) { func LoadDefaultIdentity() (*Identity, error) {
return LoadIdentity(step.IdentityFile()) return LoadIdentity(IdentityFile())
}
func profileConfigDir() string {
return filepath.Join(step.Path(), "config")
}
func profileIdentityDir() string {
return filepath.Join(step.Path(), "identity")
} }
// WriteDefaultIdentity writes the given certificates and key and the // WriteDefaultIdentity writes the given certificates and key and the
// identity.json pointing to the new files. // identity.json pointing to the new files.
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
if err := os.MkdirAll(profileConfigDir(), 0700); err != nil { if err := os.MkdirAll(configDir(), 0700); err != nil {
return errors.Wrap(err, "error creating config directory") return errors.Wrap(err, "error creating config directory")
} }
identityDir := profileIdentityDir() identityDir := identityDir()
if err := os.MkdirAll(identityDir, 0700); err != nil { if err := os.MkdirAll(identityDir, 0700); err != nil {
return errors.Wrap(err, "error creating identity directory") return errors.Wrap(err, "error creating identity directory")
} }
@ -123,7 +128,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
}); err != nil { }); err != nil {
return errors.Wrap(err, "error writing identity json") return errors.Wrap(err, "error writing identity json")
} }
if err := ioutil.WriteFile(step.IdentityFile(), buf.Bytes(), 0600); err != nil { if err := ioutil.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }
@ -132,7 +137,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
// WriteIdentityCertificate writes the identity certificate to disk. // WriteIdentityCertificate writes the identity certificate to disk.
func WriteIdentityCertificate(certChain []api.Certificate) error { func WriteIdentityCertificate(certChain []api.Certificate) error {
filename := filepath.Join(profileIdentityDir(), "identity.crt") filename := filepath.Join(identityDir(), "identity.crt")
return writeCertificate(filename, certChain) return writeCertificate(filename, certChain)
} }
@ -315,7 +320,7 @@ func (i *Identity) Renew(client Renewer) error {
return errors.Wrap(err, "error encoding identity certificate") return errors.Wrap(err, "error encoding identity certificate")
} }
} }
certFilename := filepath.Join(profileIdentityDir(), "identity.crt") certFilename := filepath.Join(identityDir(), "identity.crt")
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
return errors.Wrap(err, "error writing identity certificate") return errors.Wrap(err, "error writing identity certificate")
} }

View file

@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) {
want *Identity want *Identity
wantErr bool wantErr bool
}{ }{
{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false}, {"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false},
{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true}, {"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true},
{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true}, {"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) {
certChain = append(certChain, api.Certificate{Certificate: c}) certChain = append(certChain, api.Certificate{Certificate: c})
} }
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(tmpDir, "config", "identity.json") IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json"))
type args struct { type args struct {
certChain []api.Certificate certChain []api.Certificate
@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) {
}{ }{
{"ok", func() {}, args{certChain, key}, false}, {"ok", func() {}, args{certChain, key}, false},
{"fail mkdir config", func() { {"fail mkdir config", func() {
configDir = filepath.Join(tmpDir, "identity", "identity.crt") configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail mkdir identity", func() { {"fail mkdir identity", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity", "identity.crt") identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail certificate", func() { {"fail certificate", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
{"fail key", func() { {"fail key", func() {
configDir = filepath.Join(tmpDir, "config") configDir = returnInput(filepath.Join(tmpDir, "config"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
}, args{certChain, "badKey"}, true}, }, args{certChain, "badKey"}, true},
{"fail write identity", func() { {"fail write identity", func() {
configDir = filepath.Join(tmpDir, "bad-dir") configDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
identityDir = filepath.Join(tmpDir, "identity") identityDir = returnInput(filepath.Join(tmpDir, "identity"))
IdentityFile = filepath.Join(configDir, "identity.json") IdentityFile = returnInput(filepath.Join(configDir(), "identity.json"))
os.MkdirAll(configDir, 0600) os.MkdirAll(configDir(), 0600)
}, args{certChain, key}, true}, }, args{certChain, key}, true},
} }
@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) {
} }
oldIdentityDir := identityDir oldIdentityDir := identityDir
identityDir = "testdata/identity" identityDir = returnInput("testdata/identity")
defer func() { defer func() {
identityDir = oldIdentityDir identityDir = oldIdentityDir
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) {
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
{"fail write identity", func() { {"fail write identity", func() {
identityDir = filepath.Join(tmpDir, "bad-dir") identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
os.MkdirAll(identityDir, 0600) os.MkdirAll(identityDir(), 0600)
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
} }
for _, tt := range tests { for _, tt := range tests {