From 152b7c56a06fcdc3b83b6ecbf4cc8a0e6231d66d Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 15 Oct 2019 18:00:46 -0700 Subject: [PATCH] Add tests for templates and some fixes. --- templates/templates.go | 56 +++-- templates/templates_test.go | 420 ++++++++++++++++++++++++++++++++++++ 2 files changed, 457 insertions(+), 19 deletions(-) create mode 100644 templates/templates_test.go diff --git a/templates/templates.go b/templates/templates.go index ee4a5791..41a7e0f3 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -57,19 +57,21 @@ func (t *Templates) Validate() (err error) { // LoadAll preloads all templates in memory. It returns an error if an error is // found parsing at least one template. func LoadAll(t *Templates) (err error) { - if t.SSH != nil { - for _, tt := range t.SSH.User { - if err = tt.Load(); err != nil { - return err + if t != nil { + if t.SSH != nil { + for _, tt := range t.SSH.User { + if err = tt.Load(); err != nil { + return + } } - } - for _, tt := range t.SSH.Host { - if err = tt.Load(); err != nil { - return err + for _, tt := range t.SSH.Host { + if err = tt.Load(); err != nil { + return + } } } } - return nil + return } // SSHTemplates contains the templates defining ssh configuration files. @@ -113,18 +115,30 @@ func (t *Template) Validate() error { return nil case t.Name == "": return errors.New("template name cannot be empty") - case t.TemplatePath == "": + case t.Type != Snippet && t.Type != File && t.Type != Directory: + return errors.Errorf("invalid template type %s, it must be %s, %s, or %s", t.Type, Snippet, File, Directory) + case t.TemplatePath == "" && t.Type != Directory: return errors.New("template template cannot be empty") + case t.TemplatePath != "" && t.Type == Directory: + return errors.New("template template must be empty with directory type") case t.Path == "": return errors.New("template path cannot be empty") } - // Defaults - if t.Type == "" { - t.Type = Snippet - } - if t.Comment == "" { - t.Comment = "#" + if t.TemplatePath != "" { + // Check for file + st, err := os.Stat(config.StepAbs(t.TemplatePath)) + if err != nil { + return errors.Wrapf(err, "error reading %s", t.TemplatePath) + } + if st.IsDir() { + return errors.Errorf("error reading %s: is not a file", t.TemplatePath) + } + + // Defaults + if t.Comment == "" { + t.Comment = "#" + } } return nil @@ -133,7 +147,7 @@ func (t *Template) Validate() error { // Load loads the template in memory, returns an error if the parsing of the // template fails. func (t *Template) Load() error { - if t.Template == nil { + if t.Template == nil && t.Type != Directory { filename := config.StepAbs(t.TemplatePath) b, err := ioutil.ReadFile(filename) if err != nil { @@ -151,6 +165,10 @@ func (t *Template) Load() error { // Render executes the template with the given data and returns the rendered // version. func (t *Template) Render(data interface{}) ([]byte, error) { + if t.Type == Directory { + return nil, nil + } + if err := t.Load(); err != nil { return nil, err } @@ -172,8 +190,8 @@ func (t *Template) Output(data interface{}) (Output, error) { return Output{ Name: t.Name, Type: t.Type, - Comment: t.Comment, Path: t.Path, + Comment: t.Comment, Content: b, }, nil } @@ -182,8 +200,8 @@ func (t *Template) Output(data interface{}) (Output, error) { type Output struct { Name string `json:"name"` Type TemplateType `json:"type"` - Comment string `json:"comment"` Path string `json:"path"` + Comment string `json:"comment"` Content []byte `json:"content"` } diff --git a/templates/templates_test.go b/templates/templates_test.go new file mode 100644 index 00000000..537fab4c --- /dev/null +++ b/templates/templates_test.go @@ -0,0 +1,420 @@ +package templates + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/base64" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/smallstep/assert" + "golang.org/x/crypto/ssh" +) + +func TestTemplates_Validate(t *testing.T) { + sshTemplates := &SSHTemplates{ + User: []Template{ + {Name: "known_host.tpl", Type: File, TemplatePath: "../authority/testdata/templates/known_hosts.tpl", Path: "ssh/known_host", Comment: "#"}, + }, + Host: []Template{ + {Name: "ca.tpl", Type: File, TemplatePath: "../authority/testdata/templates/ca.tpl", Path: "/etc/ssh/ca.pub", Comment: "#"}, + }, + } + type fields struct { + SSH *SSHTemplates + Data map[string]interface{} + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok", fields{sshTemplates, nil}, false}, + {"okWithData", fields{sshTemplates, map[string]interface{}{"Foo": "Bar"}}, false}, + {"badSSH", fields{&SSHTemplates{User: []Template{{}}}, nil}, true}, + {"badDataUser", fields{sshTemplates, map[string]interface{}{"User": "Bar"}}, true}, + {"badDataStep", fields{sshTemplates, map[string]interface{}{"Step": "Bar"}}, true}, + } + var nilValue *Templates + assert.NoError(t, nilValue.Validate()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpl := &Templates{ + SSH: tt.fields.SSH, + Data: tt.fields.Data, + } + if err := tmpl.Validate(); (err != nil) != tt.wantErr { + t.Errorf("Templates.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSSHTemplates_Validate(t *testing.T) { + user := []Template{ + {Name: "include.tpl", Type: Snippet, TemplatePath: "../authority/testdata/templates/include.tpl", Path: "~/.ssh/config", Comment: "#"}, + } + host := []Template{ + {Name: "ca.tpl", Type: File, TemplatePath: "../authority/testdata/templates/ca.tpl", Path: "/etc/ssh/ca.pub", Comment: "#"}, + } + + type fields struct { + User []Template + Host []Template + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok", fields{user, host}, false}, + {"user", fields{user, nil}, false}, + {"host", fields{nil, host}, false}, + {"badUser", fields{[]Template{{}}, nil}, true}, + {"badHost", fields{nil, []Template{{}}}, true}, + } + var nilValue *SSHTemplates + assert.NoError(t, nilValue.Validate()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpl := &SSHTemplates{ + User: tt.fields.User, + Host: tt.fields.Host, + } + if err := tmpl.Validate(); (err != nil) != tt.wantErr { + t.Errorf("SSHTemplates.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestTemplate_Validate(t *testing.T) { + okPath := "~/.ssh/config" + okTmplPath := "../authority/testdata/templates/include.tpl" + + type fields struct { + Name string + Type TemplateType + TemplatePath string + Path string + Comment string + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"okSnippet", fields{"include.tpl", Snippet, okTmplPath, okPath, "#"}, false}, + {"okFile", fields{"file.tpl", File, okTmplPath, okPath, "#"}, false}, + {"okDirectory", fields{"dir.tpl", Directory, "", "/tmp/dir", "#"}, false}, + {"badName", fields{"", Snippet, okTmplPath, okPath, "#"}, true}, + {"badType", fields{"include.tpl", "", okTmplPath, okPath, "#"}, true}, + {"badType", fields{"include.tpl", "foo", okTmplPath, okPath, "#"}, true}, + {"badTemplatePath", fields{"include.tpl", Snippet, "", okPath, "#"}, true}, + {"badTemplatePath", fields{"include.tpl", File, "", okPath, "#"}, true}, + {"badTemplatePath", fields{"include.tpl", Directory, okTmplPath, okPath, "#"}, true}, + {"badPath", fields{"include.tpl", Snippet, okTmplPath, "", "#"}, true}, + {"missingTemplate", fields{"include.tpl", Snippet, "./testdata/include.tpl", okTmplPath, "#"}, true}, + {"directoryTemplate", fields{"include.tpl", File, "../authority/testdata", okTmplPath, "#"}, true}, + } + var nilValue *Template + assert.NoError(t, nilValue.Validate()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpl := &Template{ + Name: tt.fields.Name, + Type: tt.fields.Type, + TemplatePath: tt.fields.TemplatePath, + Path: tt.fields.Path, + Comment: tt.fields.Comment, + } + if err := tmpl.Validate(); (err != nil) != tt.wantErr { + t.Errorf("Template.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestLoadAll(t *testing.T) { + tmpl := &Templates{ + SSH: &SSHTemplates{ + User: []Template{ + {Name: "include.tpl", Type: Snippet, TemplatePath: "../authority/testdata/templates/include.tpl", Path: "~/.ssh/config", Comment: "#"}, + }, + Host: []Template{ + {Name: "ca.tpl", Type: File, TemplatePath: "../authority/testdata/templates/ca.tpl", Path: "/etc/ssh/ca.pub", Comment: "#"}, + }, + }, + } + + type args struct { + t *Templates + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{tmpl}, false}, + {"empty", args{&Templates{}}, false}, + {"nil", args{nil}, false}, + {"badUser", args{&Templates{SSH: &SSHTemplates{User: []Template{{}}}}}, true}, + {"badHost", args{&Templates{SSH: &SSHTemplates{Host: []Template{{}}}}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := LoadAll(tt.args.t); (err != nil) != tt.wantErr { + t.Errorf("LoadAll() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestTemplate_Load(t *testing.T) { + type fields struct { + Name string + Type TemplateType + TemplatePath string + Path string + Comment string + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"ok", fields{"include.tpl", Snippet, "../authority/testdata/templates/include.tpl", "~/.ssh/config", "#"}, false}, + {"error", fields{"error.tpl", Snippet, "../authority/testdata/templates/error.tpl", "/tmp/error", "#"}, true}, + {"missing", fields{"include.tpl", Snippet, "./testdata/include.tpl", "~/.ssh/config", "#"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpl := &Template{ + Name: tt.fields.Name, + Type: tt.fields.Type, + TemplatePath: tt.fields.TemplatePath, + Path: tt.fields.Path, + Comment: tt.fields.Comment, + } + if err := tmpl.Load(); (err != nil) != tt.wantErr { + t.Errorf("Template.Load() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestTemplate_Render(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + user, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + host, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + data := map[string]interface{}{ + "Step": &Step{ + SSH: StepSSH{ + UserKey: user, + HostKey: host, + }, + }, + "User": map[string]string{ + "StepPath": "/tmp/.step", + }, + } + + type fields struct { + Name string + Type TemplateType + TemplatePath string + Path string + Comment string + } + type args struct { + data interface{} + } + tests := []struct { + name string + fields fields + args args + want []byte + wantErr bool + }{ + {"snippet", fields{"include.tpl", Snippet, "../authority/testdata/templates/include.tpl", "~/.ssh/config", "#"}, args{data}, []byte("Host *\n\tInclude /tmp/.step/ssh/config"), false}, + {"file", fields{"known_hosts.tpl", File, "../authority/testdata/templates/known_hosts.tpl", "ssh/known_hosts", "#"}, args{data}, []byte(fmt.Sprintf("@cert-authority * %s %s", host.Type(), hostB64)), false}, + {"file", fields{"ca.tpl", File, "../authority/testdata/templates/ca.tpl", "/etc/ssh/ca.pub", "#"}, args{data}, []byte(fmt.Sprintf("%s %s", user.Type(), userB64)), false}, + {"directory", fields{"dir.tpl", Directory, "", "/tmp/dir", ""}, args{data}, nil, false}, + {"error", fields{"error.tpl", File, "../authority/testdata/templates/error.tpl", "/tmp/error", "#"}, args{data}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpl := &Template{ + Name: tt.fields.Name, + Type: tt.fields.Type, + TemplatePath: tt.fields.TemplatePath, + Path: tt.fields.Path, + Comment: tt.fields.Comment, + } + got, err := tmpl.Render(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("Template.Render() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Template.Render() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTemplate_Output(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + user, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) + + key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.FatalError(t, err) + host, err := ssh.NewPublicKey(key.Public()) + assert.FatalError(t, err) + hostB64 := base64.StdEncoding.EncodeToString(host.Marshal()) + + data := map[string]interface{}{ + "Step": &Step{ + SSH: StepSSH{ + UserKey: user, + HostKey: host, + }, + }, + "User": map[string]string{ + "StepPath": "/tmp/.step", + }, + } + + type fields struct { + Name string + Type TemplateType + TemplatePath string + Path string + Comment string + } + type args struct { + data interface{} + } + tests := []struct { + name string + fields fields + args args + want []byte + wantErr bool + }{ + {"snippet", fields{"include.tpl", Snippet, "../authority/testdata/templates/include.tpl", "~/.ssh/config", "#"}, args{data}, []byte("Host *\n\tInclude /tmp/.step/ssh/config"), false}, + {"file", fields{"known_hosts.tpl", File, "../authority/testdata/templates/known_hosts.tpl", "ssh/known_hosts", "#"}, args{data}, []byte(fmt.Sprintf("@cert-authority * %s %s", host.Type(), hostB64)), false}, + {"file", fields{"ca.tpl", File, "../authority/testdata/templates/ca.tpl", "/etc/ssh/ca.pub", "#"}, args{data}, []byte(fmt.Sprintf("%s %s", user.Type(), userB64)), false}, + {"directory", fields{"dir.tpl", Directory, "", "/tmp/dir", ""}, args{data}, nil, false}, + {"error", fields{"error.tpl", File, "../authority/testdata/templates/error.tpl", "/tmp/error", "#"}, args{data}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var want Output + if !tt.wantErr { + want = Output{ + Name: tt.fields.Name, + Type: tt.fields.Type, + Path: tt.fields.Path, + Comment: tt.fields.Comment, + Content: tt.want, + } + } + + tmpl := &Template{ + Name: tt.fields.Name, + Type: tt.fields.Type, + TemplatePath: tt.fields.TemplatePath, + Path: tt.fields.Path, + Comment: tt.fields.Comment, + } + got, err := tmpl.Output(tt.args.data) + if (err != nil) != tt.wantErr { + t.Errorf("Template.Output() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, want) { + t.Errorf("Template.Output() = %v, want %v", got, want) + } + }) + } +} + +func TestOutput_Write(t *testing.T) { + dir, err := ioutil.TempDir("", "test-output-write") + assert.FatalError(t, err) + defer os.RemoveAll(dir) + + join := func(elem ...string) string { + elems := append([]string{dir}, elem...) + return filepath.Join(elems...) + } + assert.FatalError(t, os.Mkdir(join("bad"), 0644)) + + type fields struct { + Name string + Type TemplateType + Path string + Comment string + Content []byte + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + {"snippet", fields{"snippet", Snippet, join("snippet"), "#", []byte("some content")}, false}, + {"file", fields{"file", File, join("file"), "#", []byte("some content")}, false}, + {"snippetInDir", fields{"file", Snippet, join("dir", "snippets", "snippet"), "#", []byte("some content")}, false}, + {"fileInDir", fields{"file", File, join("dir", "files", "file"), "#", []byte("some content")}, false}, + {"directory", fields{"directory", Directory, join("directory"), "", nil}, false}, + {"snippetErr", fields{"snippet", Snippet, join("bad", "snippet"), "#", []byte("some content")}, true}, + {"fileErr", fields{"file", File, join("bad", "file"), "#", []byte("some content")}, true}, + {"directoryErr", fields{"directory", Directory, join("bad", "directory"), "", nil}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + o := &Output{ + Name: tt.fields.Name, + Type: tt.fields.Type, + Comment: tt.fields.Comment, + Path: tt.fields.Path, + Content: tt.fields.Content, + } + if err := o.Write(); (err != nil) != tt.wantErr { + t.Errorf("Output.Write() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + st, err := os.Stat(o.Path) + if err != nil { + t.Errorf("os.Stat(%s) error = %v", o.Path, err) + } else { + if o.Type == Directory { + assert.True(t, st.IsDir()) + assert.Equals(t, os.ModeDir|os.FileMode(0700), st.Mode()) + } else { + assert.False(t, st.IsDir()) + assert.Equals(t, os.FileMode(0600), st.Mode()) + } + } + } + }) + } +}