Check for required variables in templates.

Fixes smallstep/cli#232
This commit is contained in:
Mariano Cano 2020-06-16 17:26:54 -07:00
parent 6c844a0618
commit 237baa5169
3 changed files with 74 additions and 0 deletions

View file

@ -159,6 +159,15 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin
// Render templates
output := []templates.Output{}
for _, t := range ts {
if err := t.Load(); err != nil {
return nil, err
}
// Check for required variables.
if err := t.ValidateRequiredData(data); err != nil {
return nil, errs.BadRequestErr(err, errs.WithMessage("%v, please use `--set <key=value>` flag", err))
}
o, err := t.Output(mergedData)
if err != nil {
return nil, err

View file

@ -106,6 +106,7 @@ type Template struct {
TemplatePath string `json:"template"`
Path string `json:"path"`
Comment string `json:"comment"`
RequiredData []string `json:"requires"`
Content []byte `json:"-"`
}
@ -147,6 +148,17 @@ func (t *Template) Validate() error {
return nil
}
// ValidateRequiredData checks that the given data contains all the keys
// required.
func (t *Template) ValidateRequiredData(data map[string]string) error {
for _, key := range t.RequiredData {
if _, ok := data[key]; !ok {
return errors.Errorf("required variable '%s' is missing", key)
}
}
return nil
}
// Load loads the template in memory, returns an error if the parsing of the
// template fails.
func (t *Template) Load() error {
@ -166,7 +178,10 @@ func (t *Template) Load() error {
return nil
}
// LoadBytes loads the template in memory, returns an error if the parsing of
// the template fails.
func (t *Template) LoadBytes(b []byte) error {
t.backfill(b)
tmpl, err := template.New(t.Name).Funcs(sprig.TxtFuncMap()).Parse(string(b))
if err != nil {
return errors.Wrapf(err, "error parsing template %s", t.Name)
@ -209,6 +224,20 @@ func (t *Template) Output(data interface{}) (Output, error) {
}, nil
}
// backfill updates old templates with the required data.
func (t *Template) backfill(b []byte) {
switch t.Name {
case "sshd_config.tpl":
if len(t.RequiredData) == 0 {
a := bytes.TrimSpace(b)
b := bytes.TrimSpace([]byte(DefaultSSHTemplateData[t.Name]))
if bytes.Equal(a, b) {
t.RequiredData = []string{"Certificate", "Key"}
}
}
}
}
// Output represents the text representation of a rendered template.
type Output struct {
Name string `json:"name"`

View file

@ -428,3 +428,39 @@ func TestOutput_Write(t *testing.T) {
})
}
}
func TestTemplate_ValidateRequiredData(t *testing.T) {
data := map[string]string{
"key1": "value1",
"key2": "value2",
}
type fields struct {
RequiredData []string
}
type args struct {
data map[string]string
}
tests := []struct {
name string
fields fields
args args
wantErr bool
}{
{"ok nil", fields{nil}, args{nil}, false},
{"ok empty", fields{[]string{}}, args{data}, false},
{"ok one", fields{[]string{"key1"}}, args{data}, false},
{"ok multiple", fields{[]string{"key1", "key2"}}, args{data}, false},
{"fail nil", fields{[]string{"missing"}}, args{nil}, true},
{"fail missing", fields{[]string{"missing"}}, args{data}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpl := &Template{
RequiredData: tt.fields.RequiredData,
}
if err := tmpl.ValidateRequiredData(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("Template.ValidateRequiredData() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}