Add package to generate ssh certificate for templates.

This commit is contained in:
Mariano Cano 2020-07-24 17:08:32 -07:00
parent 3e80f41c19
commit af3eeb870e
9 changed files with 1413 additions and 0 deletions

104
sshutil/certificate.go Normal file
View file

@ -0,0 +1,104 @@
package sshutil
import (
"crypto/rand"
"encoding/binary"
"encoding/json"
"github.com/pkg/errors"
"github.com/smallstep/cli/crypto/randutil"
"golang.org/x/crypto/ssh"
)
// Certificate is the json representation of ssh.Certificate.
type Certificate struct {
Nonce []byte `json:"nonce"`
Key ssh.PublicKey `json:"-"`
Serial uint64 `json:"serial"`
Type CertType `json:"type"`
KeyID string `json:"keyId"`
Principals []string `json:"principals"`
ValidAfter uint64 `json:"-"`
ValidBefore uint64 `json:"-"`
CriticalOptions map[string]string `json:"criticalOptions"`
Extensions map[string]string `json:"extensions"`
Reserved []byte `json:"reserved"`
SignatureKey ssh.PublicKey `json:"-"`
Signature *ssh.Signature `json:"-"`
}
// NewCertificate creates a new certificate with the given key after parsing a
// template given in the options.
func NewCertificate(key ssh.PublicKey, opts ...Option) (*Certificate, error) {
o, err := new(Options).apply(key, opts)
if err != nil {
return nil, err
}
if o.CertBuffer == nil {
return nil, errors.New("certificate template cannot be empty")
}
// With templates
var cert Certificate
if err := json.NewDecoder(o.CertBuffer).Decode(&cert); err != nil {
return nil, errors.Wrap(err, "error unmarshaling certificate")
}
// Complete with public key
cert.Key = key
return &cert, nil
}
func (c *Certificate) GetCertificate() *ssh.Certificate {
return &ssh.Certificate{
Nonce: c.Nonce,
Key: c.Key,
Serial: c.Serial,
CertType: uint32(c.Type),
KeyId: c.KeyID,
ValidPrincipals: c.Principals,
ValidAfter: c.ValidAfter,
ValidBefore: c.ValidBefore,
Permissions: ssh.Permissions{
CriticalOptions: c.CriticalOptions,
Extensions: c.Extensions,
},
Reserved: c.Reserved,
}
}
// CreateCertificate signs the given certificate with the given signer. If the
// certificate does not have a nonce or a serial, it will create random ones.
func CreateCertificate(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, error) {
if len(cert.Nonce) == 0 {
nonce, err := randutil.ASCII(32)
if err != nil {
return nil, err
}
cert.Nonce = []byte(nonce)
}
if cert.Serial == 0 {
if err := binary.Read(rand.Reader, binary.BigEndian, &cert.Serial); err != nil {
return nil, errors.Wrap(err, "error reading random number")
}
}
// Set signer public key.
cert.SignatureKey = signer.PublicKey()
// Get bytes for signing trailing the signature length.
data := cert.Marshal()
data = data[:len(data)-4]
// Sign the certificate.
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
return nil, errors.Wrap(err, "error signing certificate")
}
cert.Signature = sig
return cert, nil
}

347
sshutil/certificate_test.go Normal file
View file

@ -0,0 +1,347 @@
package sshutil
import (
"bytes"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"reflect"
"testing"
"golang.org/x/crypto/ssh"
)
func mustGenerateKey(t *testing.T) (ssh.PublicKey, ssh.Signer) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
key, err := ssh.NewPublicKey(pub)
if err != nil {
t.Fatal(err)
}
signer, err := ssh.NewSignerFromKey(priv)
if err != nil {
t.Fatal(err)
}
return key, signer
}
func mustGeneratePublicKey(t *testing.T) ssh.PublicKey {
t.Helper()
key, _ := mustGenerateKey(t)
return key
}
func TestNewCertificate(t *testing.T) {
key := mustGeneratePublicKey(t)
type args struct {
key ssh.PublicKey
opts []Option
}
tests := []struct {
name string
args args
want *Certificate
wantErr bool
}{
{"user", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(UserCert, "jane@doe.com", []string{"jane"}))}}, &Certificate{
Nonce: nil,
Key: key,
Serial: 0,
Type: UserCert,
KeyID: "jane@doe.com",
Principals: []string{"jane"},
ValidAfter: 0,
ValidBefore: 0,
CriticalOptions: nil,
Extensions: map[string]string{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
},
Reserved: nil,
SignatureKey: nil,
Signature: nil,
}, false},
{"host", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(HostCert, "foobar", []string{"foo.internal", "bar.internal"}))}}, &Certificate{
Nonce: nil,
Key: key,
Serial: 0,
Type: HostCert,
KeyID: "foobar",
Principals: []string{"foo.internal", "bar.internal"},
ValidAfter: 0,
ValidBefore: 0,
CriticalOptions: nil,
Extensions: nil,
Reserved: nil,
SignatureKey: nil,
Signature: nil,
}, false},
{"file", args{key, []Option{WithTemplateFile("./testdata/github.tpl", TemplateData{
TypeKey: UserCert,
KeyIDKey: "john@doe.com",
PrincipalsKey: []string{"john", "john@doe.com"},
ExtensionsKey: DefaultExtensions(UserCert),
InsecureKey: map[string]interface{}{
"User": map[string]interface{}{"username": "john"},
},
})}}, &Certificate{
Nonce: nil,
Key: key,
Serial: 0,
Type: UserCert,
KeyID: "john@doe.com",
Principals: []string{"john", "john@doe.com"},
ValidAfter: 0,
ValidBefore: 0,
CriticalOptions: nil,
Extensions: map[string]string{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
"login@github.com": "john",
},
Reserved: nil,
SignatureKey: nil,
Signature: nil,
}, false},
{"base64", args{key, []Option{WithTemplateBase64(base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), CreateTemplateData(HostCert, "foo.internal", nil))}}, &Certificate{
Nonce: nil,
Key: key,
Serial: 0,
Type: HostCert,
KeyID: "foo.internal",
Principals: nil,
ValidAfter: 0,
ValidBefore: 0,
CriticalOptions: nil,
Extensions: nil,
Reserved: nil,
SignatureKey: nil,
Signature: nil,
}, false},
{"failNilOptions", args{key, nil}, nil, true},
{"failEmptyOptions", args{key, nil}, nil, true},
{"badBase64Template", args{key, []Option{WithTemplateBase64("foobar", TemplateData{})}}, nil, true},
{"badFileTemplate", args{key, []Option{WithTemplateFile("./testdata/missing.tpl", TemplateData{})}}, nil, true},
{"badJsonTemplate", args{key, []Option{WithTemplate(`{"type":{{ .Type }}}`, TemplateData{})}}, nil, true},
{"failTemplate", args{key, []Option{WithTemplate(`{{ fail "an error" }}`, TemplateData{})}}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewCertificate(tt.args.key, tt.args.opts...)
if (err != nil) != tt.wantErr {
t.Errorf("NewCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewCertificate() = %v, want %v", got, tt.want)
}
})
}
}
func TestCertificate_GetCertificate(t *testing.T) {
key := mustGeneratePublicKey(t)
type fields struct {
Nonce []byte
Key ssh.PublicKey
Serial uint64
Type CertType
KeyID string
Principals []string
ValidAfter uint64
ValidBefore uint64
CriticalOptions map[string]string
Extensions map[string]string
Reserved []byte
SignatureKey ssh.PublicKey
Signature *ssh.Signature
}
tests := []struct {
name string
fields fields
want *ssh.Certificate
}{
{"user", fields{
Nonce: []byte("0123456789"),
Key: key,
Serial: 123,
Type: UserCert,
KeyID: "key-id",
Principals: []string{"john"},
ValidAfter: 1111,
ValidBefore: 2222,
CriticalOptions: map[string]string{"foo": "bar"},
Extensions: map[string]string{"login@github.com": "john"},
Reserved: []byte("reserved"),
SignatureKey: key,
Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")},
}, &ssh.Certificate{
Nonce: []byte("0123456789"),
Key: key,
Serial: 123,
CertType: ssh.UserCert,
KeyId: "key-id",
ValidPrincipals: []string{"john"},
ValidAfter: 1111,
ValidBefore: 2222,
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{"foo": "bar"},
Extensions: map[string]string{"login@github.com": "john"},
},
Reserved: []byte("reserved"),
}},
{"host", fields{
Nonce: []byte("0123456789"),
Key: key,
Serial: 123,
Type: HostCert,
KeyID: "key-id",
Principals: []string{"foo.internal", "bar.internal"},
ValidAfter: 1111,
ValidBefore: 2222,
CriticalOptions: map[string]string{"foo": "bar"},
Extensions: nil,
Reserved: []byte("reserved"),
SignatureKey: key,
Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")},
}, &ssh.Certificate{
Nonce: []byte("0123456789"),
Key: key,
Serial: 123,
CertType: ssh.HostCert,
KeyId: "key-id",
ValidPrincipals: []string{"foo.internal", "bar.internal"},
ValidAfter: 1111,
ValidBefore: 2222,
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{"foo": "bar"},
Extensions: nil,
},
Reserved: []byte("reserved"),
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Certificate{
Nonce: tt.fields.Nonce,
Key: tt.fields.Key,
Serial: tt.fields.Serial,
Type: tt.fields.Type,
KeyID: tt.fields.KeyID,
Principals: tt.fields.Principals,
ValidAfter: tt.fields.ValidAfter,
ValidBefore: tt.fields.ValidBefore,
CriticalOptions: tt.fields.CriticalOptions,
Extensions: tt.fields.Extensions,
Reserved: tt.fields.Reserved,
SignatureKey: tt.fields.SignatureKey,
Signature: tt.fields.Signature,
}
if got := c.GetCertificate(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Certificate.GetCertificate() = %v, want %v", got, tt.want)
}
})
}
}
func TestCreateCertificate(t *testing.T) {
key, signer := mustGenerateKey(t)
type args struct {
cert *ssh.Certificate
signer ssh.Signer
}
tests := []struct {
name string
args args
wantErr bool
}{
{"ok", args{&ssh.Certificate{
Nonce: []byte("0123456789"),
Key: key,
Serial: 123,
CertType: ssh.HostCert,
KeyId: "foo",
ValidPrincipals: []string{"foo.internal"},
ValidAfter: 1111,
ValidBefore: 2222,
Permissions: ssh.Permissions{},
Reserved: []byte("reserved"),
}, signer}, false},
{"emptyNonce", args{&ssh.Certificate{
Key: key,
Serial: 123,
CertType: ssh.UserCert,
KeyId: "jane@doe.com",
ValidPrincipals: []string{"jane"},
ValidAfter: 1111,
ValidBefore: 2222,
Permissions: ssh.Permissions{},
Reserved: []byte("reserved"),
}, signer}, false},
{"emptySerial", args{&ssh.Certificate{
Nonce: []byte("0123456789"),
Key: key,
CertType: ssh.UserCert,
KeyId: "jane@doe.com",
ValidPrincipals: []string{"jane"},
ValidAfter: 1111,
ValidBefore: 2222,
Permissions: ssh.Permissions{},
Reserved: []byte("reserved"),
}, signer}, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CreateCertificate(tt.args.cert, tt.args.signer)
if (err != nil) != tt.wantErr {
t.Errorf("CreateCertificate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != nil {
switch {
case len(got.Nonce) == 0:
t.Errorf("CreateCertificate() nonce should not be empty")
case got.Serial == 0:
t.Errorf("CreateCertificate() serial should not be 0")
case got.Signature == nil:
t.Errorf("CreateCertificate() signature should not be nil")
case !bytes.Equal(got.SignatureKey.Marshal(), tt.args.signer.PublicKey().Marshal()):
t.Errorf("CreateCertificate() signature key is not the expected one")
}
signature := got.Signature
got.Signature = nil
data := got.Marshal()
data = data[:len(data)-4]
sig, err := signer.Sign(rand.Reader, data)
if err != nil {
t.Errorf("signer.Sign() error = %v", err)
}
// Verify signature
got.Signature = signature
if err := signer.PublicKey().Verify(data, got.Signature); err != nil {
t.Errorf("CreateCertificate() signature verify error = %v", err)
}
// Verify data with public key in cert
if err := got.Verify(data, sig); err != nil {
t.Errorf("CreateCertificate() certificate verify error = %v", err)
}
}
})
}
}

91
sshutil/options.go Normal file
View file

@ -0,0 +1,91 @@
package sshutil
import (
"bytes"
"encoding/base64"
"io/ioutil"
"text/template"
"github.com/Masterminds/sprig/v3"
"github.com/pkg/errors"
"github.com/smallstep/cli/config"
"golang.org/x/crypto/ssh"
)
func getFuncMap(failMessage *string) template.FuncMap {
m := sprig.TxtFuncMap()
m["fail"] = func(msg string) (string, error) {
*failMessage = msg
return "", errors.New(msg)
}
return m
}
// Options are the options that can be passed to NewCertificate.
type Options struct {
CertBuffer *bytes.Buffer
}
func (o *Options) apply(key ssh.PublicKey, opts []Option) (*Options, error) {
for _, fn := range opts {
if err := fn(key, o); err != nil {
return o, err
}
}
return o, nil
}
// Option is the type used as a variadic argument in NewCertificate.
type Option func(key ssh.PublicKey, o *Options) error
// WithTemplate is an options that executes the given template text with the
// given data.
func WithTemplate(text string, data TemplateData) Option {
return func(key ssh.PublicKey, o *Options) error {
terr := new(TemplateError)
funcMap := getFuncMap(&terr.Message)
tmpl, err := template.New("template").Funcs(funcMap).Parse(text)
if err != nil {
return errors.Wrapf(err, "error parsing template")
}
buf := new(bytes.Buffer)
data.SetPublicKey(key)
if err := tmpl.Execute(buf, data); err != nil {
if terr.Message != "" {
return terr
}
return errors.Wrapf(err, "error executing template")
}
o.CertBuffer = buf
return nil
}
}
// WithTemplateBase64 is an options that executes the given template base64
// string with the given data.
func WithTemplateBase64(s string, data TemplateData) Option {
return func(key ssh.PublicKey, o *Options) error {
b, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return errors.Wrap(err, "error decoding template")
}
fn := WithTemplate(string(b), data)
return fn(key, o)
}
}
// WithTemplateFile is an options that reads the template file and executes it
// with the given data.
func WithTemplateFile(path string, data TemplateData) Option {
return func(key ssh.PublicKey, o *Options) error {
filename := config.StepAbs(path)
b, err := ioutil.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "error reading %s", path)
}
fn := WithTemplate(string(b), data)
return fn(key, o)
}
}

176
sshutil/options_test.go Normal file
View file

@ -0,0 +1,176 @@
package sshutil
import (
"bytes"
"encoding/base64"
"reflect"
"testing"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
)
func Test_getFuncMap_fail(t *testing.T) {
var failMesage string
fns := getFuncMap(&failMesage)
fail := fns["fail"].(func(s string) (string, error))
s, err := fail("the fail message")
if err == nil {
t.Errorf("fail() error = %v, wantErr %v", err, errors.New("the fail message"))
}
if s != "" {
t.Errorf("fail() = \"%s\", want \"the fail message\"", s)
}
if failMesage != "the fail message" {
t.Errorf("fail() message = \"%s\", want \"the fail message\"", failMesage)
}
}
func TestWithTemplate(t *testing.T) {
key := mustGeneratePublicKey(t)
type args struct {
text string
data TemplateData
key ssh.PublicKey
}
tests := []struct {
name string
args args
want Options
wantErr bool
}{
{"user", args{DefaultCertificate, TemplateData{
TypeKey: "user",
KeyIDKey: "jane@doe.com",
PrincipalsKey: []string{"jane", "jane@doe.com"},
ExtensionsKey: DefaultExtensions(UserCert),
}, key}, Options{
CertBuffer: bytes.NewBufferString(`{
"type": "user",
"keyId": "jane@doe.com",
"principals": ["jane","jane@doe.com"],
"extensions": {"permit-X11-forwarding":"","permit-agent-forwarding":"","permit-port-forwarding":"","permit-pty":"","permit-user-rc":""}
}`)}, false},
{"host", args{DefaultCertificate, TemplateData{
TypeKey: "host",
KeyIDKey: "foo",
PrincipalsKey: []string{"foo.internal"},
}, key}, Options{
CertBuffer: bytes.NewBufferString(`{
"type": "host",
"keyId": "foo",
"principals": ["foo.internal"],
"extensions": null
}`)}, false},
{"fail", args{`{{ fail "a message" }}`, TemplateData{}, key}, Options{}, true},
{"failTemplate", args{`{{ fail "fatal error }}`, TemplateData{}, key}, Options{}, true},
{"error", args{`{{ mustHas 3 .Data }}`, TemplateData{
"Data": 3,
}, key}, Options{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got Options
fn := WithTemplate(tt.args.text, tt.args.data)
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
t.Errorf("WithTemplate() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("WithTemplate() = %v, want %v", got, tt.want)
}
})
}
}
func TestWithTemplateBase64(t *testing.T) {
key := mustGeneratePublicKey(t)
type args struct {
s string
data TemplateData
key ssh.PublicKey
}
tests := []struct {
name string
args args
want Options
wantErr bool
}{
{"host", args{base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), TemplateData{
TypeKey: "host",
KeyIDKey: "foo.internal",
PrincipalsKey: []string{"foo.internal", "bar.internal"},
ExtensionsKey: map[string]interface{}{"foo": "bar"},
}, key}, Options{
CertBuffer: bytes.NewBufferString(`{
"type": "host",
"keyId": "foo.internal",
"principals": ["foo.internal","bar.internal"],
"extensions": {"foo":"bar"}
}`)}, false},
{"badBase64", args{"foobar", TemplateData{}, key}, Options{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got Options
fn := WithTemplateBase64(tt.args.s, tt.args.data)
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
t.Errorf("WithTemplateBase64() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("WithTemplateBase64() = %v, want %v", got, tt.want)
}
})
}
}
func TestWithTemplateFile(t *testing.T) {
key := mustGeneratePublicKey(t)
data := TemplateData{
TypeKey: "user",
KeyIDKey: "jane@doe.com",
PrincipalsKey: []string{"jane", "jane@doe.com"},
ExtensionsKey: DefaultExtensions(UserCert),
InsecureKey: map[string]interface{}{
UserKey: map[string]interface{}{
"username": "jane",
},
},
}
type args struct {
path string
data TemplateData
key ssh.PublicKey
}
tests := []struct {
name string
args args
want Options
wantErr bool
}{
{"github.com", args{"./testdata/github.tpl", data, key}, Options{
CertBuffer: bytes.NewBufferString(`{
"type": "user",
"keyId": "jane@doe.com",
"principals": ["jane","jane@doe.com"],
"extensions": {"login@github.com":"jane","permit-X11-forwarding":"","permit-agent-forwarding":"","permit-port-forwarding":"","permit-pty":"","permit-user-rc":""}
}`),
}, false},
{"missing", args{"./testdata/missing.tpl", data, key}, Options{}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got Options
fn := WithTemplateFile(tt.args.path, tt.args.data)
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
t.Errorf("WithTemplateFile() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("WithTemplateFile() = %v, want %v", got, tt.want)
}
})
}
}

132
sshutil/templates.go Normal file
View file

@ -0,0 +1,132 @@
package sshutil
import "golang.org/x/crypto/ssh"
const (
TypeKey = "Type"
KeyIDKey = "KeyID"
PrincipalsKey = "Principals"
ExtensionsKey = "Extensions"
TokenKey = "Token"
InsecureKey = "Insecure"
UserKey = "User"
PublicKey = "PublicKey"
)
// TemplateError represents an error in a template produced by the fail
// function.
type TemplateError struct {
Message string
}
// Error implements the error interface and returns the error string when a
// template executes the `fail "message"` function.
func (e *TemplateError) Error() string {
return e.Message
}
// TemplateData is an alias for map[string]interface{}. It represents the data
// passed to the templates.
type TemplateData map[string]interface{}
// CreateTemplateData returns a TemplateData with the given certificate type,
// key id, principals, and the default extensions.
func CreateTemplateData(ct CertType, keyID string, principals []string) TemplateData {
return TemplateData{
TypeKey: ct.String(),
KeyIDKey: keyID,
PrincipalsKey: principals,
ExtensionsKey: DefaultExtensions(ct),
}
}
// DefaultExtensions returns the default extensions set in an SSH certificate.
func DefaultExtensions(ct CertType) map[string]interface{} {
switch ct {
case UserCert:
return map[string]interface{}{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
}
default:
return nil
}
}
// NewTemplateData creates a new map for templates data.
func NewTemplateData() TemplateData {
return TemplateData{}
}
// AddExtension adds one extension to the templates data.
func (t TemplateData) AddExtension(key, value string) {
if m, ok := t[ExtensionsKey].(map[string]interface{}); ok {
m[key] = value
} else {
t[ExtensionsKey] = map[string]interface{}{
key: value,
}
}
}
// Set sets a key-value pair in the template data.
func (t TemplateData) Set(key string, v interface{}) {
t[key] = v
}
// SetInsecure sets a key-value pair in the insecure template data.
func (t TemplateData) SetInsecure(key string, v interface{}) {
if m, ok := t[InsecureKey].(TemplateData); ok {
m[key] = v
} else {
t[InsecureKey] = TemplateData{key: v}
}
}
// SetType sets the certificate type in the template data.
func (t TemplateData) SetType(typ CertType) {
t.Set(TypeKey, typ.String())
}
// SetType sets the certificate key id in the template data.
func (t TemplateData) SetKeyID(id string) {
t.Set(KeyIDKey, id)
}
// SetPrincipals sets the certificate principals in the template data.
func (t TemplateData) SetPrincipals(p []string) {
t.Set(PrincipalsKey, p)
}
// SetExtensions sets the certificate extensions in the template data.
func (t TemplateData) SetExtensions(e map[string]interface{}) {
t.Set(ExtensionsKey, e)
}
// SetToken sets the given token in the template data.
func (t TemplateData) SetToken(v interface{}) {
t.Set(TokenKey, v)
}
// SetUserData sets the given user provided object in the insecure template
// data.
func (t TemplateData) SetUserData(v interface{}) {
t.SetInsecure(UserKey, v)
}
// SetUserData sets the given user provided object in the insecure template
// data.
func (t TemplateData) SetPublicKey(v ssh.PublicKey) {
t.Set(PublicKey, v)
}
// DefaultCertificate is the default template for an SSH certificate.
const DefaultCertificate = `{
"type": "{{ .Type }}",
"keyId": "{{ .KeyID }}",
"principals": {{ toJson .Principals }},
"extensions": {{ toJson .Extensions }}
}`

406
sshutil/templates_test.go Normal file
View file

@ -0,0 +1,406 @@
package sshutil
import (
"reflect"
"testing"
"golang.org/x/crypto/ssh"
)
func TestTemplateError_Error(t *testing.T) {
type fields struct {
Message string
}
tests := []struct {
name string
fields fields
want string
}{
{"ok", fields{"message"}, "message"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &TemplateError{
Message: tt.fields.Message,
}
if got := e.Error(); got != tt.want {
t.Errorf("TemplateError.Error() = %v, want %v", got, tt.want)
}
})
}
}
func TestCreateTemplateData(t *testing.T) {
type args struct {
ct CertType
keyID string
principals []string
}
tests := []struct {
name string
args args
want TemplateData
}{
{"user", args{UserCert, "john@doe.com", []string{"john", "john@doe.com"}}, TemplateData{
TypeKey: "user",
KeyIDKey: "john@doe.com",
PrincipalsKey: []string{"john", "john@doe.com"},
ExtensionsKey: map[string]interface{}{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
},
}},
{"host", args{HostCert, "foo", []string{"foo.internal"}}, TemplateData{
TypeKey: "host",
KeyIDKey: "foo",
PrincipalsKey: []string{"foo.internal"},
ExtensionsKey: map[string]interface{}(nil),
}},
{"other", args{100, "foo", []string{"foo.internal"}}, TemplateData{
TypeKey: "",
KeyIDKey: "foo",
PrincipalsKey: []string{"foo.internal"},
ExtensionsKey: map[string]interface{}(nil),
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := CreateTemplateData(tt.args.ct, tt.args.keyID, tt.args.principals); !reflect.DeepEqual(got, tt.want) {
t.Errorf("CreateTemplateData() = %v, want %v", got, tt.want)
}
})
}
}
func TestDefaultExtensions(t *testing.T) {
type args struct {
ct CertType
}
tests := []struct {
name string
args args
want map[string]interface{}
}{
{"user", args{UserCert}, map[string]interface{}{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
}},
{"host", args{HostCert}, nil},
{"other", args{100}, nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := DefaultExtensions(tt.args.ct); !reflect.DeepEqual(got, tt.want) {
t.Errorf("DefaultExtensions() = %v, want %v", got, tt.want)
}
})
}
}
func TestNewTemplateData(t *testing.T) {
tests := []struct {
name string
want TemplateData
}{
{"ok", TemplateData{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := NewTemplateData(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewTemplateData() = %v, want %v", got, tt.want)
}
})
}
}
func TestTemplateData_AddExtension(t *testing.T) {
type args struct {
key string
value string
}
tests := []struct {
name string
t TemplateData
args args
want TemplateData
}{
{"empty", TemplateData{}, args{"key", "value"}, TemplateData{
ExtensionsKey: map[string]interface{}{"key": "value"},
}},
{"overwrite", TemplateData{
ExtensionsKey: map[string]interface{}{"key": "value"},
}, args{"key", "value"}, TemplateData{
ExtensionsKey: map[string]interface{}{
"key": "value",
},
}},
{"add", TemplateData{
ExtensionsKey: map[string]interface{}{"foo": "bar"},
}, args{"key", "value"}, TemplateData{
ExtensionsKey: map[string]interface{}{
"key": "value",
"foo": "bar",
},
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.t.AddExtension(tt.args.key, tt.args.value)
if !reflect.DeepEqual(tt.t, tt.want) {
t.Errorf("AddExtension() = %v, want %v", tt.t, tt.want)
}
})
}
}
func TestTemplateData_Set(t *testing.T) {
type args struct {
key string
v interface{}
}
tests := []struct {
name string
t TemplateData
args args
want TemplateData
}{
{"ok", TemplateData{}, args{"foo", "bar"}, TemplateData{
"foo": "bar",
}},
{"overwrite", TemplateData{}, args{"foo", "bar"}, TemplateData{
"foo": "bar",
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.t.Set(tt.args.key, tt.args.v)
if !reflect.DeepEqual(tt.t, tt.want) {
t.Errorf("Set() = %v, want %v", tt.t, tt.want)
}
})
}
}
func TestTemplateData_SetInsecure(t *testing.T) {
type args struct {
key string
v interface{}
}
tests := []struct {
name string
td TemplateData
args args
want TemplateData
}{
{"empty", TemplateData{}, args{"foo", "bar"}, TemplateData{InsecureKey: TemplateData{"foo": "bar"}}},
{"overwrite", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"foo", "zar"}, TemplateData{InsecureKey: TemplateData{"foo": "zar"}}},
{"add", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"bar", "foo"}, TemplateData{InsecureKey: TemplateData{"foo": "bar", "bar": "foo"}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.td.SetInsecure(tt.args.key, tt.args.v)
if !reflect.DeepEqual(tt.td, tt.want) {
t.Errorf("TemplateData.SetInsecure() = %v, want %v", tt.td, tt.want)
}
})
}
}
func TestTemplateData_SetType(t *testing.T) {
type args struct {
typ CertType
}
tests := []struct {
name string
t TemplateData
args args
want TemplateData
}{
{"user", TemplateData{}, args{UserCert}, TemplateData{
TypeKey: "user",
}},
{"host", TemplateData{}, args{HostCert}, TemplateData{
TypeKey: "host",
}},
{"overwrite", TemplateData{
TypeKey: "host",
}, args{UserCert}, TemplateData{
TypeKey: "user",
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.t.SetType(tt.args.typ)
if !reflect.DeepEqual(tt.t, tt.want) {
t.Errorf("SetType() = %v, want %v", tt.t, tt.want)
}
})
}
}
func TestTemplateData_SetKeyID(t *testing.T) {
type args struct {
id string
}
tests := []struct {
name string
t TemplateData
args args
want TemplateData
}{
{"ok", TemplateData{}, args{"key-id"}, TemplateData{
KeyIDKey: "key-id",
}},
{"overwrite", TemplateData{}, args{"key-id-2"}, TemplateData{
KeyIDKey: "key-id-2",
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.t.SetKeyID(tt.args.id)
if !reflect.DeepEqual(tt.t, tt.want) {
t.Errorf("SetKeyID() = %v, want %v", tt.t, tt.want)
}
})
}
}
func TestTemplateData_SetPrincipals(t *testing.T) {
type args struct {
p []string
}
tests := []struct {
name string
t TemplateData
args args
want TemplateData
}{
{"ok", TemplateData{}, args{[]string{"jane"}}, TemplateData{
PrincipalsKey: []string{"jane"},
}},
{"overwrite", TemplateData{}, args{[]string{"john", "john@doe.com"}}, TemplateData{
PrincipalsKey: []string{"john", "john@doe.com"},
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.t.SetPrincipals(tt.args.p)
if !reflect.DeepEqual(tt.t, tt.want) {
t.Errorf("SetPrincipals() = %v, want %v", tt.t, tt.want)
}
})
}
}
func TestTemplateData_SetExtensions(t *testing.T) {
type args struct {
e map[string]interface{}
}
tests := []struct {
name string
t TemplateData
args args
want TemplateData
}{
{"ok", TemplateData{}, args{map[string]interface{}{"foo": "bar"}}, TemplateData{
ExtensionsKey: map[string]interface{}{"foo": "bar"},
}},
{"overwrite", TemplateData{
ExtensionsKey: map[string]interface{}{"foo": "bar"},
}, args{map[string]interface{}{"key": "value"}}, TemplateData{
ExtensionsKey: map[string]interface{}{"key": "value"},
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.t.SetExtensions(tt.args.e)
if !reflect.DeepEqual(tt.t, tt.want) {
t.Errorf("SetExtensions() = %v, want %v", tt.t, tt.want)
}
})
}
}
func TestTemplateData_SetToken(t *testing.T) {
type args struct {
v interface{}
}
tests := []struct {
name string
td TemplateData
args args
want TemplateData
}{
{"ok", TemplateData{}, args{"token"}, TemplateData{TokenKey: "token"}},
{"overwrite", TemplateData{TokenKey: "foo"}, args{"token"}, TemplateData{TokenKey: "token"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.td.SetToken(tt.args.v)
if !reflect.DeepEqual(tt.td, tt.want) {
t.Errorf("TemplateData.SetToken() = %v, want %v", tt.td, tt.want)
}
})
}
}
func TestTemplateData_SetUserData(t *testing.T) {
type args struct {
v interface{}
}
tests := []struct {
name string
td TemplateData
args args
want TemplateData
}{
{"ok", TemplateData{}, args{"userData"}, TemplateData{InsecureKey: TemplateData{UserKey: "userData"}}},
{"overwrite", TemplateData{InsecureKey: TemplateData{UserKey: "foo"}}, args{"userData"}, TemplateData{InsecureKey: TemplateData{UserKey: "userData"}}},
{"existing", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"userData"}, TemplateData{InsecureKey: TemplateData{"foo": "bar", UserKey: "userData"}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.td.SetUserData(tt.args.v)
if !reflect.DeepEqual(tt.td, tt.want) {
t.Errorf("TemplateData.SetUserData() = %v, want %v", tt.td, tt.want)
}
})
}
}
func TestTemplateData_SetPublicKey(t *testing.T) {
k1 := mustGeneratePublicKey(t)
k2 := mustGeneratePublicKey(t)
type args struct {
v ssh.PublicKey
}
tests := []struct {
name string
t TemplateData
args args
want TemplateData
}{
{"ok", TemplateData{}, args{k1}, TemplateData{
PublicKey: k1,
}},
{"overwrite", TemplateData{
PublicKey: k1,
}, args{k2}, TemplateData{
PublicKey: k2,
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.t.SetPublicKey(tt.args.v)
if !reflect.DeepEqual(tt.t, tt.want) {
t.Errorf("TemplateData.SetPublicKey() = %v, want %v", tt.t, tt.want)
}
})
}
}

10
sshutil/testdata/github.tpl vendored Normal file
View file

@ -0,0 +1,10 @@
{
"type": "{{ .Type }}",
"keyId": "{{ .KeyID }}",
"principals": {{ toJson .Principals }},
{{- if .Insecure.User.username }}
"extensions": {{ set .Extensions "login@github.com" .Insecure.User.username | toJson }}
{{- else }}
"extensions": {{ toJson .Extensions }}
{{- end }}
}

View file

@ -1,5 +1,13 @@
package sshutil package sshutil
import (
"encoding/json"
"strings"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
)
// Hosts are tagged with k,v pairs. These tags are how a user is ultimately // Hosts are tagged with k,v pairs. These tags are how a user is ultimately
// associated with a host. // associated with a host.
type HostTag struct { type HostTag struct {
@ -14,3 +22,60 @@ type Host struct {
HostTags []HostTag `json:"host_tags"` HostTags []HostTag `json:"host_tags"`
Hostname string `json:"hostname"` Hostname string `json:"hostname"`
} }
// CertType defines the certificate type, it can be a user or a host
// certificate.
type CertType uint32
const (
// UserCert defines a user certificate.
UserCert CertType = ssh.UserCert
// HostCert defines a host certificate.
HostCert CertType = ssh.HostCert
)
const (
userString = "user"
hostString = "host"
)
// String returns "user" for user certificates and "host" for host certificates.
// It will return the empty string for any other value.
func (c CertType) String() string {
switch c {
case UserCert:
return userString
case HostCert:
return hostString
default:
return ""
}
}
// MarshalJSON implements the json.Marshaler interface for CertType. UserCert
// will be marshaled as the string "user" and HostCert as "host".
func (c CertType) MarshalJSON() ([]byte, error) {
if s := c.String(); s != "" {
return []byte(`"` + s + `"`), nil
}
return nil, errors.Errorf("unknown certificate type %d", c)
}
// UnmarshalJSON implements the json.Unmarshaler interface for CertType.
func (c *CertType) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return errors.Wrap(err, "error unmarshaling certificate type")
}
switch strings.ToLower(s) {
case userString:
*c = UserCert
return nil
case hostString:
*c = HostCert
return nil
default:
return errors.Errorf("error unmarshaling '%s' as a certificate type", s)
}
}

82
sshutil/types_test.go Normal file
View file

@ -0,0 +1,82 @@
package sshutil
import (
"reflect"
"testing"
)
func TestCertType_String(t *testing.T) {
tests := []struct {
name string
c CertType
want string
}{
{"user", UserCert, "user"},
{"host", HostCert, "host"},
{"empty", 100, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.c.String(); got != tt.want {
t.Errorf("CertType.String() = %v, want %v", got, tt.want)
}
})
}
}
func TestCertType_MarshalJSON(t *testing.T) {
tests := []struct {
name string
c CertType
want []byte
wantErr bool
}{
{"user", UserCert, []byte(`"user"`), false},
{"host", HostCert, []byte(`"host"`), false},
{"error", 100, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.c.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Errorf("CertType.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("CertType.MarshalJSON() = %v, want %v", got, tt.want)
}
})
}
}
func TestCertType_UnmarshalJSON(t *testing.T) {
type args struct {
data []byte
}
tests := []struct {
name string
args args
want CertType
wantErr bool
}{
{"user", args{[]byte(`"user"`)}, UserCert, false},
{"USER", args{[]byte(`"USER"`)}, UserCert, false},
{"host", args{[]byte(`"host"`)}, HostCert, false},
{"HosT", args{[]byte(`"HosT"`)}, HostCert, false},
{" user ", args{[]byte(`" user "`)}, 0, true},
{"number", args{[]byte(`1`)}, 0, true},
{"object", args{[]byte(`{}`)}, 0, true},
{"badJSON", args{[]byte(`"user`)}, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ct CertType
if err := ct.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("CertType.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(ct, tt.want) {
t.Errorf("CertType.UnmarshalJSON() = %v, want %v", ct, tt.want)
}
})
}
}