From af3eeb870e5ead176e83254119e2b282ec9dbedd Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Fri, 24 Jul 2020 17:08:32 -0700 Subject: [PATCH] Add package to generate ssh certificate for templates. --- sshutil/certificate.go | 104 +++++++++ sshutil/certificate_test.go | 347 ++++++++++++++++++++++++++++++ sshutil/options.go | 91 ++++++++ sshutil/options_test.go | 176 ++++++++++++++++ sshutil/templates.go | 132 ++++++++++++ sshutil/templates_test.go | 406 ++++++++++++++++++++++++++++++++++++ sshutil/testdata/github.tpl | 10 + sshutil/types.go | 65 ++++++ sshutil/types_test.go | 82 ++++++++ 9 files changed, 1413 insertions(+) create mode 100644 sshutil/certificate.go create mode 100644 sshutil/certificate_test.go create mode 100644 sshutil/options.go create mode 100644 sshutil/options_test.go create mode 100644 sshutil/templates.go create mode 100644 sshutil/templates_test.go create mode 100644 sshutil/testdata/github.tpl create mode 100644 sshutil/types_test.go diff --git a/sshutil/certificate.go b/sshutil/certificate.go new file mode 100644 index 00000000..473344a4 --- /dev/null +++ b/sshutil/certificate.go @@ -0,0 +1,104 @@ +package sshutil + +import ( + "crypto/rand" + "encoding/binary" + "encoding/json" + + "github.com/pkg/errors" + "github.com/smallstep/cli/crypto/randutil" + "golang.org/x/crypto/ssh" +) + +// Certificate is the json representation of ssh.Certificate. +type Certificate struct { + Nonce []byte `json:"nonce"` + Key ssh.PublicKey `json:"-"` + Serial uint64 `json:"serial"` + Type CertType `json:"type"` + KeyID string `json:"keyId"` + Principals []string `json:"principals"` + ValidAfter uint64 `json:"-"` + ValidBefore uint64 `json:"-"` + CriticalOptions map[string]string `json:"criticalOptions"` + Extensions map[string]string `json:"extensions"` + Reserved []byte `json:"reserved"` + SignatureKey ssh.PublicKey `json:"-"` + Signature *ssh.Signature `json:"-"` +} + +// NewCertificate creates a new certificate with the given key after parsing a +// template given in the options. +func NewCertificate(key ssh.PublicKey, opts ...Option) (*Certificate, error) { + o, err := new(Options).apply(key, opts) + if err != nil { + return nil, err + } + + if o.CertBuffer == nil { + return nil, errors.New("certificate template cannot be empty") + } + + // With templates + var cert Certificate + if err := json.NewDecoder(o.CertBuffer).Decode(&cert); err != nil { + return nil, errors.Wrap(err, "error unmarshaling certificate") + } + + // Complete with public key + cert.Key = key + + return &cert, nil +} + +func (c *Certificate) GetCertificate() *ssh.Certificate { + return &ssh.Certificate{ + Nonce: c.Nonce, + Key: c.Key, + Serial: c.Serial, + CertType: uint32(c.Type), + KeyId: c.KeyID, + ValidPrincipals: c.Principals, + ValidAfter: c.ValidAfter, + ValidBefore: c.ValidBefore, + Permissions: ssh.Permissions{ + CriticalOptions: c.CriticalOptions, + Extensions: c.Extensions, + }, + Reserved: c.Reserved, + } +} + +// CreateCertificate signs the given certificate with the given signer. If the +// certificate does not have a nonce or a serial, it will create random ones. +func CreateCertificate(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, error) { + if len(cert.Nonce) == 0 { + nonce, err := randutil.ASCII(32) + if err != nil { + return nil, err + } + cert.Nonce = []byte(nonce) + } + + if cert.Serial == 0 { + if err := binary.Read(rand.Reader, binary.BigEndian, &cert.Serial); err != nil { + return nil, errors.Wrap(err, "error reading random number") + } + } + + // Set signer public key. + cert.SignatureKey = signer.PublicKey() + + // Get bytes for signing trailing the signature length. + data := cert.Marshal() + data = data[:len(data)-4] + + // Sign the certificate. + sig, err := signer.Sign(rand.Reader, data) + if err != nil { + return nil, errors.Wrap(err, "error signing certificate") + } + cert.Signature = sig + + return cert, nil +} diff --git a/sshutil/certificate_test.go b/sshutil/certificate_test.go new file mode 100644 index 00000000..ed8c5f9a --- /dev/null +++ b/sshutil/certificate_test.go @@ -0,0 +1,347 @@ +package sshutil + +import ( + "bytes" + "crypto/ed25519" + "crypto/rand" + "encoding/base64" + "reflect" + "testing" + + "golang.org/x/crypto/ssh" +) + +func mustGenerateKey(t *testing.T) (ssh.PublicKey, ssh.Signer) { + t.Helper() + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + key, err := ssh.NewPublicKey(pub) + if err != nil { + t.Fatal(err) + } + signer, err := ssh.NewSignerFromKey(priv) + if err != nil { + t.Fatal(err) + } + return key, signer +} + +func mustGeneratePublicKey(t *testing.T) ssh.PublicKey { + t.Helper() + key, _ := mustGenerateKey(t) + return key +} + +func TestNewCertificate(t *testing.T) { + key := mustGeneratePublicKey(t) + + type args struct { + key ssh.PublicKey + opts []Option + } + tests := []struct { + name string + args args + want *Certificate + wantErr bool + }{ + {"user", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(UserCert, "jane@doe.com", []string{"jane"}))}}, &Certificate{ + Nonce: nil, + Key: key, + Serial: 0, + Type: UserCert, + KeyID: "jane@doe.com", + Principals: []string{"jane"}, + ValidAfter: 0, + ValidBefore: 0, + CriticalOptions: nil, + Extensions: map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + }, + Reserved: nil, + SignatureKey: nil, + Signature: nil, + }, false}, + {"host", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(HostCert, "foobar", []string{"foo.internal", "bar.internal"}))}}, &Certificate{ + Nonce: nil, + Key: key, + Serial: 0, + Type: HostCert, + KeyID: "foobar", + Principals: []string{"foo.internal", "bar.internal"}, + ValidAfter: 0, + ValidBefore: 0, + CriticalOptions: nil, + Extensions: nil, + Reserved: nil, + SignatureKey: nil, + Signature: nil, + }, false}, + {"file", args{key, []Option{WithTemplateFile("./testdata/github.tpl", TemplateData{ + TypeKey: UserCert, + KeyIDKey: "john@doe.com", + PrincipalsKey: []string{"john", "john@doe.com"}, + ExtensionsKey: DefaultExtensions(UserCert), + InsecureKey: map[string]interface{}{ + "User": map[string]interface{}{"username": "john"}, + }, + })}}, &Certificate{ + Nonce: nil, + Key: key, + Serial: 0, + Type: UserCert, + KeyID: "john@doe.com", + Principals: []string{"john", "john@doe.com"}, + ValidAfter: 0, + ValidBefore: 0, + CriticalOptions: nil, + Extensions: map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + "login@github.com": "john", + }, + Reserved: nil, + SignatureKey: nil, + Signature: nil, + }, false}, + {"base64", args{key, []Option{WithTemplateBase64(base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), CreateTemplateData(HostCert, "foo.internal", nil))}}, &Certificate{ + Nonce: nil, + Key: key, + Serial: 0, + Type: HostCert, + KeyID: "foo.internal", + Principals: nil, + ValidAfter: 0, + ValidBefore: 0, + CriticalOptions: nil, + Extensions: nil, + Reserved: nil, + SignatureKey: nil, + Signature: nil, + }, false}, + {"failNilOptions", args{key, nil}, nil, true}, + {"failEmptyOptions", args{key, nil}, nil, true}, + {"badBase64Template", args{key, []Option{WithTemplateBase64("foobar", TemplateData{})}}, nil, true}, + {"badFileTemplate", args{key, []Option{WithTemplateFile("./testdata/missing.tpl", TemplateData{})}}, nil, true}, + {"badJsonTemplate", args{key, []Option{WithTemplate(`{"type":{{ .Type }}}`, TemplateData{})}}, nil, true}, + {"failTemplate", args{key, []Option{WithTemplate(`{{ fail "an error" }}`, TemplateData{})}}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewCertificate(tt.args.key, tt.args.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("NewCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewCertificate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCertificate_GetCertificate(t *testing.T) { + key := mustGeneratePublicKey(t) + + type fields struct { + Nonce []byte + Key ssh.PublicKey + Serial uint64 + Type CertType + KeyID string + Principals []string + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions map[string]string + Extensions map[string]string + Reserved []byte + SignatureKey ssh.PublicKey + Signature *ssh.Signature + } + tests := []struct { + name string + fields fields + want *ssh.Certificate + }{ + {"user", fields{ + Nonce: []byte("0123456789"), + Key: key, + Serial: 123, + Type: UserCert, + KeyID: "key-id", + Principals: []string{"john"}, + ValidAfter: 1111, + ValidBefore: 2222, + CriticalOptions: map[string]string{"foo": "bar"}, + Extensions: map[string]string{"login@github.com": "john"}, + Reserved: []byte("reserved"), + SignatureKey: key, + Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")}, + }, &ssh.Certificate{ + Nonce: []byte("0123456789"), + Key: key, + Serial: 123, + CertType: ssh.UserCert, + KeyId: "key-id", + ValidPrincipals: []string{"john"}, + ValidAfter: 1111, + ValidBefore: 2222, + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{"foo": "bar"}, + Extensions: map[string]string{"login@github.com": "john"}, + }, + Reserved: []byte("reserved"), + }}, + {"host", fields{ + Nonce: []byte("0123456789"), + Key: key, + Serial: 123, + Type: HostCert, + KeyID: "key-id", + Principals: []string{"foo.internal", "bar.internal"}, + ValidAfter: 1111, + ValidBefore: 2222, + CriticalOptions: map[string]string{"foo": "bar"}, + Extensions: nil, + Reserved: []byte("reserved"), + SignatureKey: key, + Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")}, + }, &ssh.Certificate{ + Nonce: []byte("0123456789"), + Key: key, + Serial: 123, + CertType: ssh.HostCert, + KeyId: "key-id", + ValidPrincipals: []string{"foo.internal", "bar.internal"}, + ValidAfter: 1111, + ValidBefore: 2222, + Permissions: ssh.Permissions{ + CriticalOptions: map[string]string{"foo": "bar"}, + Extensions: nil, + }, + Reserved: []byte("reserved"), + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Certificate{ + Nonce: tt.fields.Nonce, + Key: tt.fields.Key, + Serial: tt.fields.Serial, + Type: tt.fields.Type, + KeyID: tt.fields.KeyID, + Principals: tt.fields.Principals, + ValidAfter: tt.fields.ValidAfter, + ValidBefore: tt.fields.ValidBefore, + CriticalOptions: tt.fields.CriticalOptions, + Extensions: tt.fields.Extensions, + Reserved: tt.fields.Reserved, + SignatureKey: tt.fields.SignatureKey, + Signature: tt.fields.Signature, + } + if got := c.GetCertificate(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Certificate.GetCertificate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCreateCertificate(t *testing.T) { + key, signer := mustGenerateKey(t) + type args struct { + cert *ssh.Certificate + signer ssh.Signer + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok", args{&ssh.Certificate{ + Nonce: []byte("0123456789"), + Key: key, + Serial: 123, + CertType: ssh.HostCert, + KeyId: "foo", + ValidPrincipals: []string{"foo.internal"}, + ValidAfter: 1111, + ValidBefore: 2222, + Permissions: ssh.Permissions{}, + Reserved: []byte("reserved"), + }, signer}, false}, + {"emptyNonce", args{&ssh.Certificate{ + Key: key, + Serial: 123, + CertType: ssh.UserCert, + KeyId: "jane@doe.com", + ValidPrincipals: []string{"jane"}, + ValidAfter: 1111, + ValidBefore: 2222, + Permissions: ssh.Permissions{}, + Reserved: []byte("reserved"), + }, signer}, false}, + {"emptySerial", args{&ssh.Certificate{ + Nonce: []byte("0123456789"), + Key: key, + CertType: ssh.UserCert, + KeyId: "jane@doe.com", + ValidPrincipals: []string{"jane"}, + ValidAfter: 1111, + ValidBefore: 2222, + Permissions: ssh.Permissions{}, + Reserved: []byte("reserved"), + }, signer}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CreateCertificate(tt.args.cert, tt.args.signer) + if (err != nil) != tt.wantErr { + t.Errorf("CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != nil { + switch { + case len(got.Nonce) == 0: + t.Errorf("CreateCertificate() nonce should not be empty") + case got.Serial == 0: + t.Errorf("CreateCertificate() serial should not be 0") + case got.Signature == nil: + t.Errorf("CreateCertificate() signature should not be nil") + case !bytes.Equal(got.SignatureKey.Marshal(), tt.args.signer.PublicKey().Marshal()): + t.Errorf("CreateCertificate() signature key is not the expected one") + } + + signature := got.Signature + got.Signature = nil + + data := got.Marshal() + data = data[:len(data)-4] + + sig, err := signer.Sign(rand.Reader, data) + if err != nil { + t.Errorf("signer.Sign() error = %v", err) + } + + // Verify signature + got.Signature = signature + if err := signer.PublicKey().Verify(data, got.Signature); err != nil { + t.Errorf("CreateCertificate() signature verify error = %v", err) + } + // Verify data with public key in cert + if err := got.Verify(data, sig); err != nil { + t.Errorf("CreateCertificate() certificate verify error = %v", err) + } + + } + }) + } +} diff --git a/sshutil/options.go b/sshutil/options.go new file mode 100644 index 00000000..58562ba9 --- /dev/null +++ b/sshutil/options.go @@ -0,0 +1,91 @@ +package sshutil + +import ( + "bytes" + "encoding/base64" + "io/ioutil" + "text/template" + + "github.com/Masterminds/sprig/v3" + "github.com/pkg/errors" + "github.com/smallstep/cli/config" + "golang.org/x/crypto/ssh" +) + +func getFuncMap(failMessage *string) template.FuncMap { + m := sprig.TxtFuncMap() + m["fail"] = func(msg string) (string, error) { + *failMessage = msg + return "", errors.New(msg) + } + return m +} + +// Options are the options that can be passed to NewCertificate. +type Options struct { + CertBuffer *bytes.Buffer +} + +func (o *Options) apply(key ssh.PublicKey, opts []Option) (*Options, error) { + for _, fn := range opts { + if err := fn(key, o); err != nil { + return o, err + } + } + return o, nil +} + +// Option is the type used as a variadic argument in NewCertificate. +type Option func(key ssh.PublicKey, o *Options) error + +// WithTemplate is an options that executes the given template text with the +// given data. +func WithTemplate(text string, data TemplateData) Option { + return func(key ssh.PublicKey, o *Options) error { + terr := new(TemplateError) + funcMap := getFuncMap(&terr.Message) + + tmpl, err := template.New("template").Funcs(funcMap).Parse(text) + if err != nil { + return errors.Wrapf(err, "error parsing template") + } + + buf := new(bytes.Buffer) + data.SetPublicKey(key) + if err := tmpl.Execute(buf, data); err != nil { + if terr.Message != "" { + return terr + } + return errors.Wrapf(err, "error executing template") + } + o.CertBuffer = buf + return nil + } +} + +// WithTemplateBase64 is an options that executes the given template base64 +// string with the given data. +func WithTemplateBase64(s string, data TemplateData) Option { + return func(key ssh.PublicKey, o *Options) error { + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return errors.Wrap(err, "error decoding template") + } + fn := WithTemplate(string(b), data) + return fn(key, o) + } +} + +// WithTemplateFile is an options that reads the template file and executes it +// with the given data. +func WithTemplateFile(path string, data TemplateData) Option { + return func(key ssh.PublicKey, o *Options) error { + filename := config.StepAbs(path) + b, err := ioutil.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "error reading %s", path) + } + fn := WithTemplate(string(b), data) + return fn(key, o) + } +} diff --git a/sshutil/options_test.go b/sshutil/options_test.go new file mode 100644 index 00000000..ac4da67d --- /dev/null +++ b/sshutil/options_test.go @@ -0,0 +1,176 @@ +package sshutil + +import ( + "bytes" + "encoding/base64" + "reflect" + "testing" + + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + +func Test_getFuncMap_fail(t *testing.T) { + var failMesage string + fns := getFuncMap(&failMesage) + fail := fns["fail"].(func(s string) (string, error)) + s, err := fail("the fail message") + if err == nil { + t.Errorf("fail() error = %v, wantErr %v", err, errors.New("the fail message")) + } + if s != "" { + t.Errorf("fail() = \"%s\", want \"the fail message\"", s) + } + if failMesage != "the fail message" { + t.Errorf("fail() message = \"%s\", want \"the fail message\"", failMesage) + } +} + +func TestWithTemplate(t *testing.T) { + key := mustGeneratePublicKey(t) + + type args struct { + text string + data TemplateData + key ssh.PublicKey + } + tests := []struct { + name string + args args + want Options + wantErr bool + }{ + {"user", args{DefaultCertificate, TemplateData{ + TypeKey: "user", + KeyIDKey: "jane@doe.com", + PrincipalsKey: []string{"jane", "jane@doe.com"}, + ExtensionsKey: DefaultExtensions(UserCert), + }, key}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "type": "user", + "keyId": "jane@doe.com", + "principals": ["jane","jane@doe.com"], + "extensions": {"permit-X11-forwarding":"","permit-agent-forwarding":"","permit-port-forwarding":"","permit-pty":"","permit-user-rc":""} +}`)}, false}, + {"host", args{DefaultCertificate, TemplateData{ + TypeKey: "host", + KeyIDKey: "foo", + PrincipalsKey: []string{"foo.internal"}, + }, key}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "type": "host", + "keyId": "foo", + "principals": ["foo.internal"], + "extensions": null +}`)}, false}, + {"fail", args{`{{ fail "a message" }}`, TemplateData{}, key}, Options{}, true}, + {"failTemplate", args{`{{ fail "fatal error }}`, TemplateData{}, key}, Options{}, true}, + {"error", args{`{{ mustHas 3 .Data }}`, TemplateData{ + "Data": 3, + }, key}, Options{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Options + fn := WithTemplate(tt.args.text, tt.args.data) + if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr { + t.Errorf("WithTemplate() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithTemplate() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWithTemplateBase64(t *testing.T) { + key := mustGeneratePublicKey(t) + + type args struct { + s string + data TemplateData + key ssh.PublicKey + } + tests := []struct { + name string + args args + want Options + wantErr bool + }{ + {"host", args{base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), TemplateData{ + TypeKey: "host", + KeyIDKey: "foo.internal", + PrincipalsKey: []string{"foo.internal", "bar.internal"}, + ExtensionsKey: map[string]interface{}{"foo": "bar"}, + }, key}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "type": "host", + "keyId": "foo.internal", + "principals": ["foo.internal","bar.internal"], + "extensions": {"foo":"bar"} +}`)}, false}, + {"badBase64", args{"foobar", TemplateData{}, key}, Options{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Options + fn := WithTemplateBase64(tt.args.s, tt.args.data) + if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr { + t.Errorf("WithTemplateBase64() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithTemplateBase64() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWithTemplateFile(t *testing.T) { + key := mustGeneratePublicKey(t) + + data := TemplateData{ + TypeKey: "user", + KeyIDKey: "jane@doe.com", + PrincipalsKey: []string{"jane", "jane@doe.com"}, + ExtensionsKey: DefaultExtensions(UserCert), + InsecureKey: map[string]interface{}{ + UserKey: map[string]interface{}{ + "username": "jane", + }, + }, + } + + type args struct { + path string + data TemplateData + key ssh.PublicKey + } + tests := []struct { + name string + args args + want Options + wantErr bool + }{ + {"github.com", args{"./testdata/github.tpl", data, key}, Options{ + CertBuffer: bytes.NewBufferString(`{ + "type": "user", + "keyId": "jane@doe.com", + "principals": ["jane","jane@doe.com"], + "extensions": {"login@github.com":"jane","permit-X11-forwarding":"","permit-agent-forwarding":"","permit-port-forwarding":"","permit-pty":"","permit-user-rc":""} +}`), + }, false}, + {"missing", args{"./testdata/missing.tpl", data, key}, Options{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got Options + fn := WithTemplateFile(tt.args.path, tt.args.data) + if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr { + t.Errorf("WithTemplateFile() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("WithTemplateFile() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sshutil/templates.go b/sshutil/templates.go new file mode 100644 index 00000000..c925da3a --- /dev/null +++ b/sshutil/templates.go @@ -0,0 +1,132 @@ +package sshutil + +import "golang.org/x/crypto/ssh" + +const ( + TypeKey = "Type" + KeyIDKey = "KeyID" + PrincipalsKey = "Principals" + ExtensionsKey = "Extensions" + TokenKey = "Token" + InsecureKey = "Insecure" + UserKey = "User" + PublicKey = "PublicKey" +) + +// TemplateError represents an error in a template produced by the fail +// function. +type TemplateError struct { + Message string +} + +// Error implements the error interface and returns the error string when a +// template executes the `fail "message"` function. +func (e *TemplateError) Error() string { + return e.Message +} + +// TemplateData is an alias for map[string]interface{}. It represents the data +// passed to the templates. +type TemplateData map[string]interface{} + +// CreateTemplateData returns a TemplateData with the given certificate type, +// key id, principals, and the default extensions. +func CreateTemplateData(ct CertType, keyID string, principals []string) TemplateData { + return TemplateData{ + TypeKey: ct.String(), + KeyIDKey: keyID, + PrincipalsKey: principals, + ExtensionsKey: DefaultExtensions(ct), + } +} + +// DefaultExtensions returns the default extensions set in an SSH certificate. +func DefaultExtensions(ct CertType) map[string]interface{} { + switch ct { + case UserCert: + return map[string]interface{}{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + } + default: + return nil + } +} + +// NewTemplateData creates a new map for templates data. +func NewTemplateData() TemplateData { + return TemplateData{} +} + +// AddExtension adds one extension to the templates data. +func (t TemplateData) AddExtension(key, value string) { + if m, ok := t[ExtensionsKey].(map[string]interface{}); ok { + m[key] = value + } else { + t[ExtensionsKey] = map[string]interface{}{ + key: value, + } + } +} + +// Set sets a key-value pair in the template data. +func (t TemplateData) Set(key string, v interface{}) { + t[key] = v +} + +// SetInsecure sets a key-value pair in the insecure template data. +func (t TemplateData) SetInsecure(key string, v interface{}) { + if m, ok := t[InsecureKey].(TemplateData); ok { + m[key] = v + } else { + t[InsecureKey] = TemplateData{key: v} + } +} + +// SetType sets the certificate type in the template data. +func (t TemplateData) SetType(typ CertType) { + t.Set(TypeKey, typ.String()) +} + +// SetType sets the certificate key id in the template data. +func (t TemplateData) SetKeyID(id string) { + t.Set(KeyIDKey, id) +} + +// SetPrincipals sets the certificate principals in the template data. +func (t TemplateData) SetPrincipals(p []string) { + t.Set(PrincipalsKey, p) +} + +// SetExtensions sets the certificate extensions in the template data. +func (t TemplateData) SetExtensions(e map[string]interface{}) { + t.Set(ExtensionsKey, e) +} + +// SetToken sets the given token in the template data. +func (t TemplateData) SetToken(v interface{}) { + t.Set(TokenKey, v) +} + +// SetUserData sets the given user provided object in the insecure template +// data. +func (t TemplateData) SetUserData(v interface{}) { + t.SetInsecure(UserKey, v) +} + +// SetUserData sets the given user provided object in the insecure template +// data. +func (t TemplateData) SetPublicKey(v ssh.PublicKey) { + t.Set(PublicKey, v) +} + +// DefaultCertificate is the default template for an SSH certificate. +const DefaultCertificate = `{ + "type": "{{ .Type }}", + "keyId": "{{ .KeyID }}", + "principals": {{ toJson .Principals }}, + "extensions": {{ toJson .Extensions }} +}` diff --git a/sshutil/templates_test.go b/sshutil/templates_test.go new file mode 100644 index 00000000..ad550ccb --- /dev/null +++ b/sshutil/templates_test.go @@ -0,0 +1,406 @@ +package sshutil + +import ( + "reflect" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestTemplateError_Error(t *testing.T) { + type fields struct { + Message string + } + tests := []struct { + name string + fields fields + want string + }{ + {"ok", fields{"message"}, "message"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &TemplateError{ + Message: tt.fields.Message, + } + if got := e.Error(); got != tt.want { + t.Errorf("TemplateError.Error() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCreateTemplateData(t *testing.T) { + type args struct { + ct CertType + keyID string + principals []string + } + tests := []struct { + name string + args args + want TemplateData + }{ + {"user", args{UserCert, "john@doe.com", []string{"john", "john@doe.com"}}, TemplateData{ + TypeKey: "user", + KeyIDKey: "john@doe.com", + PrincipalsKey: []string{"john", "john@doe.com"}, + ExtensionsKey: map[string]interface{}{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + }, + }}, + {"host", args{HostCert, "foo", []string{"foo.internal"}}, TemplateData{ + TypeKey: "host", + KeyIDKey: "foo", + PrincipalsKey: []string{"foo.internal"}, + ExtensionsKey: map[string]interface{}(nil), + }}, + {"other", args{100, "foo", []string{"foo.internal"}}, TemplateData{ + TypeKey: "", + KeyIDKey: "foo", + PrincipalsKey: []string{"foo.internal"}, + ExtensionsKey: map[string]interface{}(nil), + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CreateTemplateData(tt.args.ct, tt.args.keyID, tt.args.principals); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CreateTemplateData() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDefaultExtensions(t *testing.T) { + type args struct { + ct CertType + } + tests := []struct { + name string + args args + want map[string]interface{} + }{ + {"user", args{UserCert}, map[string]interface{}{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + }}, + {"host", args{HostCert}, nil}, + {"other", args{100}, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := DefaultExtensions(tt.args.ct); !reflect.DeepEqual(got, tt.want) { + t.Errorf("DefaultExtensions() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewTemplateData(t *testing.T) { + tests := []struct { + name string + want TemplateData + }{ + {"ok", TemplateData{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewTemplateData(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTemplateData() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTemplateData_AddExtension(t *testing.T) { + type args struct { + key string + value string + } + tests := []struct { + name string + t TemplateData + args args + want TemplateData + }{ + {"empty", TemplateData{}, args{"key", "value"}, TemplateData{ + ExtensionsKey: map[string]interface{}{"key": "value"}, + }}, + {"overwrite", TemplateData{ + ExtensionsKey: map[string]interface{}{"key": "value"}, + }, args{"key", "value"}, TemplateData{ + ExtensionsKey: map[string]interface{}{ + "key": "value", + }, + }}, + {"add", TemplateData{ + ExtensionsKey: map[string]interface{}{"foo": "bar"}, + }, args{"key", "value"}, TemplateData{ + ExtensionsKey: map[string]interface{}{ + "key": "value", + "foo": "bar", + }, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.t.AddExtension(tt.args.key, tt.args.value) + if !reflect.DeepEqual(tt.t, tt.want) { + t.Errorf("AddExtension() = %v, want %v", tt.t, tt.want) + } + }) + } +} + +func TestTemplateData_Set(t *testing.T) { + type args struct { + key string + v interface{} + } + tests := []struct { + name string + t TemplateData + args args + want TemplateData + }{ + {"ok", TemplateData{}, args{"foo", "bar"}, TemplateData{ + "foo": "bar", + }}, + {"overwrite", TemplateData{}, args{"foo", "bar"}, TemplateData{ + "foo": "bar", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.t.Set(tt.args.key, tt.args.v) + if !reflect.DeepEqual(tt.t, tt.want) { + t.Errorf("Set() = %v, want %v", tt.t, tt.want) + } + }) + } +} + +func TestTemplateData_SetInsecure(t *testing.T) { + type args struct { + key string + v interface{} + } + tests := []struct { + name string + td TemplateData + args args + want TemplateData + }{ + {"empty", TemplateData{}, args{"foo", "bar"}, TemplateData{InsecureKey: TemplateData{"foo": "bar"}}}, + {"overwrite", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"foo", "zar"}, TemplateData{InsecureKey: TemplateData{"foo": "zar"}}}, + {"add", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"bar", "foo"}, TemplateData{InsecureKey: TemplateData{"foo": "bar", "bar": "foo"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.td.SetInsecure(tt.args.key, tt.args.v) + if !reflect.DeepEqual(tt.td, tt.want) { + t.Errorf("TemplateData.SetInsecure() = %v, want %v", tt.td, tt.want) + } + }) + } +} + +func TestTemplateData_SetType(t *testing.T) { + type args struct { + typ CertType + } + tests := []struct { + name string + t TemplateData + args args + want TemplateData + }{ + {"user", TemplateData{}, args{UserCert}, TemplateData{ + TypeKey: "user", + }}, + {"host", TemplateData{}, args{HostCert}, TemplateData{ + TypeKey: "host", + }}, + {"overwrite", TemplateData{ + TypeKey: "host", + }, args{UserCert}, TemplateData{ + TypeKey: "user", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.t.SetType(tt.args.typ) + if !reflect.DeepEqual(tt.t, tt.want) { + t.Errorf("SetType() = %v, want %v", tt.t, tt.want) + } + }) + } +} + +func TestTemplateData_SetKeyID(t *testing.T) { + type args struct { + id string + } + tests := []struct { + name string + t TemplateData + args args + want TemplateData + }{ + {"ok", TemplateData{}, args{"key-id"}, TemplateData{ + KeyIDKey: "key-id", + }}, + {"overwrite", TemplateData{}, args{"key-id-2"}, TemplateData{ + KeyIDKey: "key-id-2", + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.t.SetKeyID(tt.args.id) + if !reflect.DeepEqual(tt.t, tt.want) { + t.Errorf("SetKeyID() = %v, want %v", tt.t, tt.want) + } + }) + } +} + +func TestTemplateData_SetPrincipals(t *testing.T) { + type args struct { + p []string + } + tests := []struct { + name string + t TemplateData + args args + want TemplateData + }{ + {"ok", TemplateData{}, args{[]string{"jane"}}, TemplateData{ + PrincipalsKey: []string{"jane"}, + }}, + {"overwrite", TemplateData{}, args{[]string{"john", "john@doe.com"}}, TemplateData{ + PrincipalsKey: []string{"john", "john@doe.com"}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.t.SetPrincipals(tt.args.p) + if !reflect.DeepEqual(tt.t, tt.want) { + t.Errorf("SetPrincipals() = %v, want %v", tt.t, tt.want) + } + }) + } +} + +func TestTemplateData_SetExtensions(t *testing.T) { + type args struct { + e map[string]interface{} + } + tests := []struct { + name string + t TemplateData + args args + want TemplateData + }{ + {"ok", TemplateData{}, args{map[string]interface{}{"foo": "bar"}}, TemplateData{ + ExtensionsKey: map[string]interface{}{"foo": "bar"}, + }}, + {"overwrite", TemplateData{ + ExtensionsKey: map[string]interface{}{"foo": "bar"}, + }, args{map[string]interface{}{"key": "value"}}, TemplateData{ + ExtensionsKey: map[string]interface{}{"key": "value"}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.t.SetExtensions(tt.args.e) + if !reflect.DeepEqual(tt.t, tt.want) { + t.Errorf("SetExtensions() = %v, want %v", tt.t, tt.want) + } + }) + } +} + +func TestTemplateData_SetToken(t *testing.T) { + type args struct { + v interface{} + } + tests := []struct { + name string + td TemplateData + args args + want TemplateData + }{ + {"ok", TemplateData{}, args{"token"}, TemplateData{TokenKey: "token"}}, + {"overwrite", TemplateData{TokenKey: "foo"}, args{"token"}, TemplateData{TokenKey: "token"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.td.SetToken(tt.args.v) + if !reflect.DeepEqual(tt.td, tt.want) { + t.Errorf("TemplateData.SetToken() = %v, want %v", tt.td, tt.want) + } + }) + } +} + +func TestTemplateData_SetUserData(t *testing.T) { + type args struct { + v interface{} + } + tests := []struct { + name string + td TemplateData + args args + want TemplateData + }{ + {"ok", TemplateData{}, args{"userData"}, TemplateData{InsecureKey: TemplateData{UserKey: "userData"}}}, + {"overwrite", TemplateData{InsecureKey: TemplateData{UserKey: "foo"}}, args{"userData"}, TemplateData{InsecureKey: TemplateData{UserKey: "userData"}}}, + {"existing", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"userData"}, TemplateData{InsecureKey: TemplateData{"foo": "bar", UserKey: "userData"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.td.SetUserData(tt.args.v) + if !reflect.DeepEqual(tt.td, tt.want) { + t.Errorf("TemplateData.SetUserData() = %v, want %v", tt.td, tt.want) + } + }) + } +} + +func TestTemplateData_SetPublicKey(t *testing.T) { + k1 := mustGeneratePublicKey(t) + k2 := mustGeneratePublicKey(t) + type args struct { + v ssh.PublicKey + } + tests := []struct { + name string + t TemplateData + args args + want TemplateData + }{ + {"ok", TemplateData{}, args{k1}, TemplateData{ + PublicKey: k1, + }}, + {"overwrite", TemplateData{ + PublicKey: k1, + }, args{k2}, TemplateData{ + PublicKey: k2, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.t.SetPublicKey(tt.args.v) + if !reflect.DeepEqual(tt.t, tt.want) { + t.Errorf("TemplateData.SetPublicKey() = %v, want %v", tt.t, tt.want) + } + }) + } +} diff --git a/sshutil/testdata/github.tpl b/sshutil/testdata/github.tpl new file mode 100644 index 00000000..2340522f --- /dev/null +++ b/sshutil/testdata/github.tpl @@ -0,0 +1,10 @@ +{ + "type": "{{ .Type }}", + "keyId": "{{ .KeyID }}", + "principals": {{ toJson .Principals }}, +{{- if .Insecure.User.username }} + "extensions": {{ set .Extensions "login@github.com" .Insecure.User.username | toJson }} +{{- else }} + "extensions": {{ toJson .Extensions }} +{{- end }} +} \ No newline at end of file diff --git a/sshutil/types.go b/sshutil/types.go index 9300771b..bf62933b 100644 --- a/sshutil/types.go +++ b/sshutil/types.go @@ -1,5 +1,13 @@ package sshutil +import ( + "encoding/json" + "strings" + + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + // Hosts are tagged with k,v pairs. These tags are how a user is ultimately // associated with a host. type HostTag struct { @@ -14,3 +22,60 @@ type Host struct { HostTags []HostTag `json:"host_tags"` Hostname string `json:"hostname"` } + +// CertType defines the certificate type, it can be a user or a host +// certificate. +type CertType uint32 + +const ( + // UserCert defines a user certificate. + UserCert CertType = ssh.UserCert + + // HostCert defines a host certificate. + HostCert CertType = ssh.HostCert +) + +const ( + userString = "user" + hostString = "host" +) + +// String returns "user" for user certificates and "host" for host certificates. +// It will return the empty string for any other value. +func (c CertType) String() string { + switch c { + case UserCert: + return userString + case HostCert: + return hostString + default: + return "" + } +} + +// MarshalJSON implements the json.Marshaler interface for CertType. UserCert +// will be marshaled as the string "user" and HostCert as "host". +func (c CertType) MarshalJSON() ([]byte, error) { + if s := c.String(); s != "" { + return []byte(`"` + s + `"`), nil + } + return nil, errors.Errorf("unknown certificate type %d", c) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for CertType. +func (c *CertType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return errors.Wrap(err, "error unmarshaling certificate type") + } + switch strings.ToLower(s) { + case userString: + *c = UserCert + return nil + case hostString: + *c = HostCert + return nil + default: + return errors.Errorf("error unmarshaling '%s' as a certificate type", s) + } +} diff --git a/sshutil/types_test.go b/sshutil/types_test.go new file mode 100644 index 00000000..15306554 --- /dev/null +++ b/sshutil/types_test.go @@ -0,0 +1,82 @@ +package sshutil + +import ( + "reflect" + "testing" +) + +func TestCertType_String(t *testing.T) { + tests := []struct { + name string + c CertType + want string + }{ + {"user", UserCert, "user"}, + {"host", HostCert, "host"}, + {"empty", 100, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.c.String(); got != tt.want { + t.Errorf("CertType.String() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCertType_MarshalJSON(t *testing.T) { + tests := []struct { + name string + c CertType + want []byte + wantErr bool + }{ + {"user", UserCert, []byte(`"user"`), false}, + {"host", HostCert, []byte(`"host"`), false}, + {"error", 100, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.c.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("CertType.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CertType.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCertType_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want CertType + wantErr bool + }{ + {"user", args{[]byte(`"user"`)}, UserCert, false}, + {"USER", args{[]byte(`"USER"`)}, UserCert, false}, + {"host", args{[]byte(`"host"`)}, HostCert, false}, + {"HosT", args{[]byte(`"HosT"`)}, HostCert, false}, + {" user ", args{[]byte(`" user "`)}, 0, true}, + {"number", args{[]byte(`1`)}, 0, true}, + {"object", args{[]byte(`{}`)}, 0, true}, + {"badJSON", args{[]byte(`"user`)}, 0, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ct CertType + if err := ct.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("CertType.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(ct, tt.want) { + t.Errorf("CertType.UnmarshalJSON() = %v, want %v", ct, tt.want) + } + }) + } +}