forked from TrueCloudLab/certificates
SSH backwards compat updates
- use existence of new value in data map as boolean - add tests for backwards and forwards compatibility - fix old tests that used static dir locations
This commit is contained in:
parent
d37313bef4
commit
a7d144996f
7 changed files with 105 additions and 61 deletions
|
@ -102,10 +102,14 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Backwards compatibility for version of the cli older than v0.18.0
|
||||
if o.Name == "step_includes.tpl" && (data == nil || data[templates.SSHTemplateVersionKey] != "v2") {
|
||||
o.Type = templates.File
|
||||
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
|
||||
// Backwards compatibility for version of the cli older than v0.18.0.
|
||||
// Before v0.18.0 we were not passing any value for SSHTemplateVersionKey
|
||||
// from the cli.
|
||||
if o.Name == "step_includes.tpl" {
|
||||
if val, ok := data[templates.SSHTemplateVersionKey]; !ok || val == "" {
|
||||
o.Type = templates.File
|
||||
o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/")
|
||||
}
|
||||
}
|
||||
|
||||
output = append(output, o)
|
||||
|
|
|
@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
|
|||
{Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")},
|
||||
}
|
||||
|
||||
tmplConfigUserIncludes := &templates.Templates{
|
||||
SSH: &templates.SSHTemplates{
|
||||
User: []templates.Template{
|
||||
{Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"},
|
||||
},
|
||||
},
|
||||
Data: map[string]interface{}{
|
||||
"Step": &templates.Step{
|
||||
SSH: templates.StepSSH{
|
||||
UserKey: user,
|
||||
HostKey: host,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
userOutputEmptyData := []templates.Output{
|
||||
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"<no value>/ssh/config\"\n")},
|
||||
}
|
||||
userOutputWithoutTemplateVersion := []templates.Output{
|
||||
{Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
|
||||
}
|
||||
userOutputWithTemplateVersion := []templates.Output{
|
||||
{Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")},
|
||||
}
|
||||
|
||||
tmplConfigErr := &templates.Templates{
|
||||
SSH: &templates.SSHTemplates{
|
||||
User: []templates.Template{
|
||||
|
@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) {
|
|||
{"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false},
|
||||
{"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false},
|
||||
{"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false},
|
||||
{"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false},
|
||||
{"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false},
|
||||
{"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false},
|
||||
{"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true},
|
||||
{"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true},
|
||||
{"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true},
|
||||
|
|
1
authority/testdata/templates/step_includes.tpl
vendored
Normal file
1
authority/testdata/templates/step_includes.tpl
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }}
|
|
@ -10,7 +10,6 @@ import (
|
|||
"net/url"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"go.step.sm/cli-utils/step"
|
||||
)
|
||||
|
||||
// Client wraps http.Client with a transport using the step root and identity.
|
||||
|
@ -28,7 +27,7 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL {
|
|||
// $STEPPATH/config/defaults.json and the identity defined in
|
||||
// $STEPPATH/config/identity.json
|
||||
func LoadClient() (*Client, error) {
|
||||
defaultsFile := step.DefaultsFile()
|
||||
defaultsFile := DefaultsFile()
|
||||
b, err := ioutil.ReadFile(defaultsFile)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "error reading %s", defaultsFile)
|
||||
|
@ -54,7 +53,7 @@ func LoadClient() (*Client, error) {
|
|||
return nil, err
|
||||
}
|
||||
if err := identity.Validate(); err != nil {
|
||||
return nil, errors.Wrapf(err, "error validating %s", step.IdentityFile())
|
||||
return nil, errors.Wrapf(err, "error validating %s", IdentityFile())
|
||||
}
|
||||
if kind := identity.Kind(); kind != MutualTLS {
|
||||
return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind)
|
||||
|
|
|
@ -11,6 +11,12 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func returnInput(val string) func() string {
|
||||
return func() string {
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
oldIdentityFile := IdentityFile
|
||||
oldDefaultsFile := DefaultsFile
|
||||
|
@ -19,8 +25,8 @@ func TestClient(t *testing.T) {
|
|||
DefaultsFile = oldDefaultsFile
|
||||
}()
|
||||
|
||||
IdentityFile = "testdata/config/identity.json"
|
||||
DefaultsFile = "testdata/config/defaults.json"
|
||||
IdentityFile = returnInput("testdata/config/identity.json")
|
||||
DefaultsFile = returnInput("testdata/config/defaults.json")
|
||||
|
||||
client, err := LoadClient()
|
||||
if err != nil {
|
||||
|
@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) {
|
|||
wantErr bool
|
||||
}{
|
||||
{"ok", func() {
|
||||
IdentityFile = "testdata/config/identity.json"
|
||||
DefaultsFile = "testdata/config/defaults.json"
|
||||
IdentityFile = returnInput("testdata/config/identity.json")
|
||||
DefaultsFile = returnInput("testdata/config/defaults.json")
|
||||
}, expected, false},
|
||||
{"fail identity", func() {
|
||||
IdentityFile = "testdata/config/missing.json"
|
||||
DefaultsFile = "testdata/config/defaults.json"
|
||||
IdentityFile = returnInput("testdata/config/missing.json")
|
||||
DefaultsFile = returnInput("testdata/config/defaults.json")
|
||||
}, nil, true},
|
||||
{"fail identity", func() {
|
||||
IdentityFile = "testdata/config/fail.json"
|
||||
DefaultsFile = "testdata/config/defaults.json"
|
||||
IdentityFile = returnInput("testdata/config/fail.json")
|
||||
DefaultsFile = returnInput("testdata/config/defaults.json")
|
||||
}, nil, true},
|
||||
{"fail defaults", func() {
|
||||
IdentityFile = "testdata/config/identity.json"
|
||||
DefaultsFile = "testdata/config/missing.json"
|
||||
IdentityFile = returnInput("testdata/config/identity.json")
|
||||
DefaultsFile = returnInput("testdata/config/missing.json")
|
||||
}, nil, true},
|
||||
{"fail defaults", func() {
|
||||
IdentityFile = "testdata/config/identity.json"
|
||||
DefaultsFile = "testdata/config/fail.json"
|
||||
IdentityFile = returnInput("testdata/config/identity.json")
|
||||
DefaultsFile = returnInput("testdata/config/fail.json")
|
||||
}, nil, true},
|
||||
{"fail ca", func() {
|
||||
IdentityFile = "testdata/config/identity.json"
|
||||
DefaultsFile = "testdata/config/badca.json"
|
||||
IdentityFile = returnInput("testdata/config/identity.json")
|
||||
DefaultsFile = returnInput("testdata/config/badca.json")
|
||||
}, nil, true},
|
||||
{"fail root", func() {
|
||||
IdentityFile = "testdata/config/identity.json"
|
||||
DefaultsFile = "testdata/config/badroot.json"
|
||||
IdentityFile = returnInput("testdata/config/identity.json")
|
||||
DefaultsFile = returnInput("testdata/config/badroot.json")
|
||||
}, nil, true},
|
||||
{"fail type", func() {
|
||||
IdentityFile = "testdata/config/badIdentity.json"
|
||||
DefaultsFile = "testdata/config/defaults.json"
|
||||
IdentityFile = returnInput("testdata/config/badIdentity.json")
|
||||
DefaultsFile = returnInput("testdata/config/defaults.json")
|
||||
}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
|
|
@ -39,6 +39,19 @@ const TunnelTLS Type = "tTLS"
|
|||
// DefaultLeeway is the duration for matching not before claims.
|
||||
const DefaultLeeway = 1 * time.Minute
|
||||
|
||||
var (
|
||||
identityDir = step.IdentityPath
|
||||
configDir = step.ConfigPath
|
||||
|
||||
// IdentityFile contains a pointer to a function that outputs the location of
|
||||
// the identity file.
|
||||
IdentityFile = step.IdentityFile
|
||||
|
||||
// DefaultsFile contains a prointer a function that outputs the location of the
|
||||
// defaults configuration file.
|
||||
DefaultsFile = step.DefaultsFile
|
||||
)
|
||||
|
||||
// Identity represents the identity file that can be used to authenticate with
|
||||
// the CA.
|
||||
type Identity struct {
|
||||
|
@ -68,25 +81,17 @@ func LoadIdentity(filename string) (*Identity, error) {
|
|||
|
||||
// LoadDefaultIdentity loads the default identity.
|
||||
func LoadDefaultIdentity() (*Identity, error) {
|
||||
return LoadIdentity(step.IdentityFile())
|
||||
}
|
||||
|
||||
func profileConfigDir() string {
|
||||
return filepath.Join(step.Path(), "config")
|
||||
}
|
||||
|
||||
func profileIdentityDir() string {
|
||||
return filepath.Join(step.Path(), "identity")
|
||||
return LoadIdentity(IdentityFile())
|
||||
}
|
||||
|
||||
// WriteDefaultIdentity writes the given certificates and key and the
|
||||
// identity.json pointing to the new files.
|
||||
func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error {
|
||||
if err := os.MkdirAll(profileConfigDir(), 0700); err != nil {
|
||||
if err := os.MkdirAll(configDir(), 0700); err != nil {
|
||||
return errors.Wrap(err, "error creating config directory")
|
||||
}
|
||||
|
||||
identityDir := profileIdentityDir()
|
||||
identityDir := identityDir()
|
||||
if err := os.MkdirAll(identityDir, 0700); err != nil {
|
||||
return errors.Wrap(err, "error creating identity directory")
|
||||
}
|
||||
|
@ -123,7 +128,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
|
|||
}); err != nil {
|
||||
return errors.Wrap(err, "error writing identity json")
|
||||
}
|
||||
if err := ioutil.WriteFile(step.IdentityFile(), buf.Bytes(), 0600); err != nil {
|
||||
if err := ioutil.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil {
|
||||
return errors.Wrap(err, "error writing identity certificate")
|
||||
}
|
||||
|
||||
|
@ -132,7 +137,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er
|
|||
|
||||
// WriteIdentityCertificate writes the identity certificate to disk.
|
||||
func WriteIdentityCertificate(certChain []api.Certificate) error {
|
||||
filename := filepath.Join(profileIdentityDir(), "identity.crt")
|
||||
filename := filepath.Join(identityDir(), "identity.crt")
|
||||
return writeCertificate(filename, certChain)
|
||||
}
|
||||
|
||||
|
@ -315,7 +320,7 @@ func (i *Identity) Renew(client Renewer) error {
|
|||
return errors.Wrap(err, "error encoding identity certificate")
|
||||
}
|
||||
}
|
||||
certFilename := filepath.Join(profileIdentityDir(), "identity.crt")
|
||||
certFilename := filepath.Join(identityDir(), "identity.crt")
|
||||
if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil {
|
||||
return errors.Wrap(err, "error writing identity certificate")
|
||||
}
|
||||
|
|
|
@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) {
|
|||
want *Identity
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false},
|
||||
{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true},
|
||||
{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true},
|
||||
{"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false},
|
||||
{"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true},
|
||||
{"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) {
|
|||
certChain = append(certChain, api.Certificate{Certificate: c})
|
||||
}
|
||||
|
||||
configDir = filepath.Join(tmpDir, "config")
|
||||
identityDir = filepath.Join(tmpDir, "identity")
|
||||
IdentityFile = filepath.Join(tmpDir, "config", "identity.json")
|
||||
configDir = returnInput(filepath.Join(tmpDir, "config"))
|
||||
identityDir = returnInput(filepath.Join(tmpDir, "identity"))
|
||||
IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json"))
|
||||
|
||||
type args struct {
|
||||
certChain []api.Certificate
|
||||
|
@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) {
|
|||
}{
|
||||
{"ok", func() {}, args{certChain, key}, false},
|
||||
{"fail mkdir config", func() {
|
||||
configDir = filepath.Join(tmpDir, "identity", "identity.crt")
|
||||
identityDir = filepath.Join(tmpDir, "identity")
|
||||
configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
|
||||
identityDir = returnInput(filepath.Join(tmpDir, "identity"))
|
||||
}, args{certChain, key}, true},
|
||||
{"fail mkdir identity", func() {
|
||||
configDir = filepath.Join(tmpDir, "config")
|
||||
identityDir = filepath.Join(tmpDir, "identity", "identity.crt")
|
||||
configDir = returnInput(filepath.Join(tmpDir, "config"))
|
||||
identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt"))
|
||||
}, args{certChain, key}, true},
|
||||
{"fail certificate", func() {
|
||||
configDir = filepath.Join(tmpDir, "config")
|
||||
identityDir = filepath.Join(tmpDir, "bad-dir")
|
||||
os.MkdirAll(identityDir, 0600)
|
||||
configDir = returnInput(filepath.Join(tmpDir, "config"))
|
||||
identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
|
||||
os.MkdirAll(identityDir(), 0600)
|
||||
}, args{certChain, key}, true},
|
||||
{"fail key", func() {
|
||||
configDir = filepath.Join(tmpDir, "config")
|
||||
identityDir = filepath.Join(tmpDir, "identity")
|
||||
configDir = returnInput(filepath.Join(tmpDir, "config"))
|
||||
identityDir = returnInput(filepath.Join(tmpDir, "identity"))
|
||||
}, args{certChain, "badKey"}, true},
|
||||
{"fail write identity", func() {
|
||||
configDir = filepath.Join(tmpDir, "bad-dir")
|
||||
identityDir = filepath.Join(tmpDir, "identity")
|
||||
IdentityFile = filepath.Join(configDir, "identity.json")
|
||||
os.MkdirAll(configDir, 0600)
|
||||
configDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
|
||||
identityDir = returnInput(filepath.Join(tmpDir, "identity"))
|
||||
IdentityFile = returnInput(filepath.Join(configDir(), "identity.json"))
|
||||
os.MkdirAll(configDir(), 0600)
|
||||
}, args{certChain, key}, true},
|
||||
}
|
||||
|
||||
|
@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) {
|
|||
}
|
||||
|
||||
oldIdentityDir := identityDir
|
||||
identityDir = "testdata/identity"
|
||||
identityDir = returnInput("testdata/identity")
|
||||
defer func() {
|
||||
identityDir = oldIdentityDir
|
||||
os.RemoveAll(tmpDir)
|
||||
|
@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) {
|
|||
{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
|
||||
{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
|
||||
{"fail write identity", func() {
|
||||
identityDir = filepath.Join(tmpDir, "bad-dir")
|
||||
os.MkdirAll(identityDir, 0600)
|
||||
identityDir = returnInput(filepath.Join(tmpDir, "bad-dir"))
|
||||
os.MkdirAll(identityDir(), 0600)
|
||||
}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
|
Loading…
Reference in a new issue