Add package to generate ssh certificate for templates.
This commit is contained in:
parent
3e80f41c19
commit
af3eeb870e
9 changed files with 1413 additions and 0 deletions
104
sshutil/certificate.go
Normal file
104
sshutil/certificate.go
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/crypto/randutil"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Certificate is the json representation of ssh.Certificate.
|
||||||
|
type Certificate struct {
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Key ssh.PublicKey `json:"-"`
|
||||||
|
Serial uint64 `json:"serial"`
|
||||||
|
Type CertType `json:"type"`
|
||||||
|
KeyID string `json:"keyId"`
|
||||||
|
Principals []string `json:"principals"`
|
||||||
|
ValidAfter uint64 `json:"-"`
|
||||||
|
ValidBefore uint64 `json:"-"`
|
||||||
|
CriticalOptions map[string]string `json:"criticalOptions"`
|
||||||
|
Extensions map[string]string `json:"extensions"`
|
||||||
|
Reserved []byte `json:"reserved"`
|
||||||
|
SignatureKey ssh.PublicKey `json:"-"`
|
||||||
|
Signature *ssh.Signature `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCertificate creates a new certificate with the given key after parsing a
|
||||||
|
// template given in the options.
|
||||||
|
func NewCertificate(key ssh.PublicKey, opts ...Option) (*Certificate, error) {
|
||||||
|
o, err := new(Options).apply(key, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.CertBuffer == nil {
|
||||||
|
return nil, errors.New("certificate template cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// With templates
|
||||||
|
var cert Certificate
|
||||||
|
if err := json.NewDecoder(o.CertBuffer).Decode(&cert); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error unmarshaling certificate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete with public key
|
||||||
|
cert.Key = key
|
||||||
|
|
||||||
|
return &cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Certificate) GetCertificate() *ssh.Certificate {
|
||||||
|
return &ssh.Certificate{
|
||||||
|
Nonce: c.Nonce,
|
||||||
|
Key: c.Key,
|
||||||
|
Serial: c.Serial,
|
||||||
|
CertType: uint32(c.Type),
|
||||||
|
KeyId: c.KeyID,
|
||||||
|
ValidPrincipals: c.Principals,
|
||||||
|
ValidAfter: c.ValidAfter,
|
||||||
|
ValidBefore: c.ValidBefore,
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: c.CriticalOptions,
|
||||||
|
Extensions: c.Extensions,
|
||||||
|
},
|
||||||
|
Reserved: c.Reserved,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateCertificate signs the given certificate with the given signer. If the
|
||||||
|
// certificate does not have a nonce or a serial, it will create random ones.
|
||||||
|
func CreateCertificate(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, error) {
|
||||||
|
if len(cert.Nonce) == 0 {
|
||||||
|
nonce, err := randutil.ASCII(32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert.Nonce = []byte(nonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cert.Serial == 0 {
|
||||||
|
if err := binary.Read(rand.Reader, binary.BigEndian, &cert.Serial); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error reading random number")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set signer public key.
|
||||||
|
cert.SignatureKey = signer.PublicKey()
|
||||||
|
|
||||||
|
// Get bytes for signing trailing the signature length.
|
||||||
|
data := cert.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
// Sign the certificate.
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "error signing certificate")
|
||||||
|
}
|
||||||
|
cert.Signature = sig
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
347
sshutil/certificate_test.go
Normal file
347
sshutil/certificate_test.go
Normal file
|
@ -0,0 +1,347 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustGenerateKey(t *testing.T) (ssh.PublicKey, ssh.Signer) {
|
||||||
|
t.Helper()
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
key, err := ssh.NewPublicKey(pub)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
signer, err := ssh.NewSignerFromKey(priv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return key, signer
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustGeneratePublicKey(t *testing.T) ssh.PublicKey {
|
||||||
|
t.Helper()
|
||||||
|
key, _ := mustGenerateKey(t)
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCertificate(t *testing.T) {
|
||||||
|
key := mustGeneratePublicKey(t)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
key ssh.PublicKey
|
||||||
|
opts []Option
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *Certificate
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"user", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(UserCert, "jane@doe.com", []string{"jane"}))}}, &Certificate{
|
||||||
|
Nonce: nil,
|
||||||
|
Key: key,
|
||||||
|
Serial: 0,
|
||||||
|
Type: UserCert,
|
||||||
|
KeyID: "jane@doe.com",
|
||||||
|
Principals: []string{"jane"},
|
||||||
|
ValidAfter: 0,
|
||||||
|
ValidBefore: 0,
|
||||||
|
CriticalOptions: nil,
|
||||||
|
Extensions: map[string]string{
|
||||||
|
"permit-X11-forwarding": "",
|
||||||
|
"permit-agent-forwarding": "",
|
||||||
|
"permit-port-forwarding": "",
|
||||||
|
"permit-pty": "",
|
||||||
|
"permit-user-rc": "",
|
||||||
|
},
|
||||||
|
Reserved: nil,
|
||||||
|
SignatureKey: nil,
|
||||||
|
Signature: nil,
|
||||||
|
}, false},
|
||||||
|
{"host", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(HostCert, "foobar", []string{"foo.internal", "bar.internal"}))}}, &Certificate{
|
||||||
|
Nonce: nil,
|
||||||
|
Key: key,
|
||||||
|
Serial: 0,
|
||||||
|
Type: HostCert,
|
||||||
|
KeyID: "foobar",
|
||||||
|
Principals: []string{"foo.internal", "bar.internal"},
|
||||||
|
ValidAfter: 0,
|
||||||
|
ValidBefore: 0,
|
||||||
|
CriticalOptions: nil,
|
||||||
|
Extensions: nil,
|
||||||
|
Reserved: nil,
|
||||||
|
SignatureKey: nil,
|
||||||
|
Signature: nil,
|
||||||
|
}, false},
|
||||||
|
{"file", args{key, []Option{WithTemplateFile("./testdata/github.tpl", TemplateData{
|
||||||
|
TypeKey: UserCert,
|
||||||
|
KeyIDKey: "john@doe.com",
|
||||||
|
PrincipalsKey: []string{"john", "john@doe.com"},
|
||||||
|
ExtensionsKey: DefaultExtensions(UserCert),
|
||||||
|
InsecureKey: map[string]interface{}{
|
||||||
|
"User": map[string]interface{}{"username": "john"},
|
||||||
|
},
|
||||||
|
})}}, &Certificate{
|
||||||
|
Nonce: nil,
|
||||||
|
Key: key,
|
||||||
|
Serial: 0,
|
||||||
|
Type: UserCert,
|
||||||
|
KeyID: "john@doe.com",
|
||||||
|
Principals: []string{"john", "john@doe.com"},
|
||||||
|
ValidAfter: 0,
|
||||||
|
ValidBefore: 0,
|
||||||
|
CriticalOptions: nil,
|
||||||
|
Extensions: map[string]string{
|
||||||
|
"permit-X11-forwarding": "",
|
||||||
|
"permit-agent-forwarding": "",
|
||||||
|
"permit-port-forwarding": "",
|
||||||
|
"permit-pty": "",
|
||||||
|
"permit-user-rc": "",
|
||||||
|
"login@github.com": "john",
|
||||||
|
},
|
||||||
|
Reserved: nil,
|
||||||
|
SignatureKey: nil,
|
||||||
|
Signature: nil,
|
||||||
|
}, false},
|
||||||
|
{"base64", args{key, []Option{WithTemplateBase64(base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), CreateTemplateData(HostCert, "foo.internal", nil))}}, &Certificate{
|
||||||
|
Nonce: nil,
|
||||||
|
Key: key,
|
||||||
|
Serial: 0,
|
||||||
|
Type: HostCert,
|
||||||
|
KeyID: "foo.internal",
|
||||||
|
Principals: nil,
|
||||||
|
ValidAfter: 0,
|
||||||
|
ValidBefore: 0,
|
||||||
|
CriticalOptions: nil,
|
||||||
|
Extensions: nil,
|
||||||
|
Reserved: nil,
|
||||||
|
SignatureKey: nil,
|
||||||
|
Signature: nil,
|
||||||
|
}, false},
|
||||||
|
{"failNilOptions", args{key, nil}, nil, true},
|
||||||
|
{"failEmptyOptions", args{key, nil}, nil, true},
|
||||||
|
{"badBase64Template", args{key, []Option{WithTemplateBase64("foobar", TemplateData{})}}, nil, true},
|
||||||
|
{"badFileTemplate", args{key, []Option{WithTemplateFile("./testdata/missing.tpl", TemplateData{})}}, nil, true},
|
||||||
|
{"badJsonTemplate", args{key, []Option{WithTemplate(`{"type":{{ .Type }}}`, TemplateData{})}}, nil, true},
|
||||||
|
{"failTemplate", args{key, []Option{WithTemplate(`{{ fail "an error" }}`, TemplateData{})}}, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := NewCertificate(tt.args.key, tt.args.opts...)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("NewCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("NewCertificate() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertificate_GetCertificate(t *testing.T) {
|
||||||
|
key := mustGeneratePublicKey(t)
|
||||||
|
|
||||||
|
type fields struct {
|
||||||
|
Nonce []byte
|
||||||
|
Key ssh.PublicKey
|
||||||
|
Serial uint64
|
||||||
|
Type CertType
|
||||||
|
KeyID string
|
||||||
|
Principals []string
|
||||||
|
ValidAfter uint64
|
||||||
|
ValidBefore uint64
|
||||||
|
CriticalOptions map[string]string
|
||||||
|
Extensions map[string]string
|
||||||
|
Reserved []byte
|
||||||
|
SignatureKey ssh.PublicKey
|
||||||
|
Signature *ssh.Signature
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want *ssh.Certificate
|
||||||
|
}{
|
||||||
|
{"user", fields{
|
||||||
|
Nonce: []byte("0123456789"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 123,
|
||||||
|
Type: UserCert,
|
||||||
|
KeyID: "key-id",
|
||||||
|
Principals: []string{"john"},
|
||||||
|
ValidAfter: 1111,
|
||||||
|
ValidBefore: 2222,
|
||||||
|
CriticalOptions: map[string]string{"foo": "bar"},
|
||||||
|
Extensions: map[string]string{"login@github.com": "john"},
|
||||||
|
Reserved: []byte("reserved"),
|
||||||
|
SignatureKey: key,
|
||||||
|
Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")},
|
||||||
|
}, &ssh.Certificate{
|
||||||
|
Nonce: []byte("0123456789"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 123,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: "key-id",
|
||||||
|
ValidPrincipals: []string{"john"},
|
||||||
|
ValidAfter: 1111,
|
||||||
|
ValidBefore: 2222,
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: map[string]string{"foo": "bar"},
|
||||||
|
Extensions: map[string]string{"login@github.com": "john"},
|
||||||
|
},
|
||||||
|
Reserved: []byte("reserved"),
|
||||||
|
}},
|
||||||
|
{"host", fields{
|
||||||
|
Nonce: []byte("0123456789"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 123,
|
||||||
|
Type: HostCert,
|
||||||
|
KeyID: "key-id",
|
||||||
|
Principals: []string{"foo.internal", "bar.internal"},
|
||||||
|
ValidAfter: 1111,
|
||||||
|
ValidBefore: 2222,
|
||||||
|
CriticalOptions: map[string]string{"foo": "bar"},
|
||||||
|
Extensions: nil,
|
||||||
|
Reserved: []byte("reserved"),
|
||||||
|
SignatureKey: key,
|
||||||
|
Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")},
|
||||||
|
}, &ssh.Certificate{
|
||||||
|
Nonce: []byte("0123456789"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 123,
|
||||||
|
CertType: ssh.HostCert,
|
||||||
|
KeyId: "key-id",
|
||||||
|
ValidPrincipals: []string{"foo.internal", "bar.internal"},
|
||||||
|
ValidAfter: 1111,
|
||||||
|
ValidBefore: 2222,
|
||||||
|
Permissions: ssh.Permissions{
|
||||||
|
CriticalOptions: map[string]string{"foo": "bar"},
|
||||||
|
Extensions: nil,
|
||||||
|
},
|
||||||
|
Reserved: []byte("reserved"),
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &Certificate{
|
||||||
|
Nonce: tt.fields.Nonce,
|
||||||
|
Key: tt.fields.Key,
|
||||||
|
Serial: tt.fields.Serial,
|
||||||
|
Type: tt.fields.Type,
|
||||||
|
KeyID: tt.fields.KeyID,
|
||||||
|
Principals: tt.fields.Principals,
|
||||||
|
ValidAfter: tt.fields.ValidAfter,
|
||||||
|
ValidBefore: tt.fields.ValidBefore,
|
||||||
|
CriticalOptions: tt.fields.CriticalOptions,
|
||||||
|
Extensions: tt.fields.Extensions,
|
||||||
|
Reserved: tt.fields.Reserved,
|
||||||
|
SignatureKey: tt.fields.SignatureKey,
|
||||||
|
Signature: tt.fields.Signature,
|
||||||
|
}
|
||||||
|
if got := c.GetCertificate(); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Certificate.GetCertificate() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateCertificate(t *testing.T) {
|
||||||
|
key, signer := mustGenerateKey(t)
|
||||||
|
type args struct {
|
||||||
|
cert *ssh.Certificate
|
||||||
|
signer ssh.Signer
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"ok", args{&ssh.Certificate{
|
||||||
|
Nonce: []byte("0123456789"),
|
||||||
|
Key: key,
|
||||||
|
Serial: 123,
|
||||||
|
CertType: ssh.HostCert,
|
||||||
|
KeyId: "foo",
|
||||||
|
ValidPrincipals: []string{"foo.internal"},
|
||||||
|
ValidAfter: 1111,
|
||||||
|
ValidBefore: 2222,
|
||||||
|
Permissions: ssh.Permissions{},
|
||||||
|
Reserved: []byte("reserved"),
|
||||||
|
}, signer}, false},
|
||||||
|
{"emptyNonce", args{&ssh.Certificate{
|
||||||
|
Key: key,
|
||||||
|
Serial: 123,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: "jane@doe.com",
|
||||||
|
ValidPrincipals: []string{"jane"},
|
||||||
|
ValidAfter: 1111,
|
||||||
|
ValidBefore: 2222,
|
||||||
|
Permissions: ssh.Permissions{},
|
||||||
|
Reserved: []byte("reserved"),
|
||||||
|
}, signer}, false},
|
||||||
|
{"emptySerial", args{&ssh.Certificate{
|
||||||
|
Nonce: []byte("0123456789"),
|
||||||
|
Key: key,
|
||||||
|
CertType: ssh.UserCert,
|
||||||
|
KeyId: "jane@doe.com",
|
||||||
|
ValidPrincipals: []string{"jane"},
|
||||||
|
ValidAfter: 1111,
|
||||||
|
ValidBefore: 2222,
|
||||||
|
Permissions: ssh.Permissions{},
|
||||||
|
Reserved: []byte("reserved"),
|
||||||
|
}, signer}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := CreateCertificate(tt.args.cert, tt.args.signer)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("CreateCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
switch {
|
||||||
|
case len(got.Nonce) == 0:
|
||||||
|
t.Errorf("CreateCertificate() nonce should not be empty")
|
||||||
|
case got.Serial == 0:
|
||||||
|
t.Errorf("CreateCertificate() serial should not be 0")
|
||||||
|
case got.Signature == nil:
|
||||||
|
t.Errorf("CreateCertificate() signature should not be nil")
|
||||||
|
case !bytes.Equal(got.SignatureKey.Marshal(), tt.args.signer.PublicKey().Marshal()):
|
||||||
|
t.Errorf("CreateCertificate() signature key is not the expected one")
|
||||||
|
}
|
||||||
|
|
||||||
|
signature := got.Signature
|
||||||
|
got.Signature = nil
|
||||||
|
|
||||||
|
data := got.Marshal()
|
||||||
|
data = data[:len(data)-4]
|
||||||
|
|
||||||
|
sig, err := signer.Sign(rand.Reader, data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("signer.Sign() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify signature
|
||||||
|
got.Signature = signature
|
||||||
|
if err := signer.PublicKey().Verify(data, got.Signature); err != nil {
|
||||||
|
t.Errorf("CreateCertificate() signature verify error = %v", err)
|
||||||
|
}
|
||||||
|
// Verify data with public key in cert
|
||||||
|
if err := got.Verify(data, sig); err != nil {
|
||||||
|
t.Errorf("CreateCertificate() certificate verify error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
91
sshutil/options.go
Normal file
91
sshutil/options.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"io/ioutil"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"github.com/Masterminds/sprig/v3"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/smallstep/cli/config"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getFuncMap(failMessage *string) template.FuncMap {
|
||||||
|
m := sprig.TxtFuncMap()
|
||||||
|
m["fail"] = func(msg string) (string, error) {
|
||||||
|
*failMessage = msg
|
||||||
|
return "", errors.New(msg)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options are the options that can be passed to NewCertificate.
|
||||||
|
type Options struct {
|
||||||
|
CertBuffer *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) apply(key ssh.PublicKey, opts []Option) (*Options, error) {
|
||||||
|
for _, fn := range opts {
|
||||||
|
if err := fn(key, o); err != nil {
|
||||||
|
return o, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option is the type used as a variadic argument in NewCertificate.
|
||||||
|
type Option func(key ssh.PublicKey, o *Options) error
|
||||||
|
|
||||||
|
// WithTemplate is an options that executes the given template text with the
|
||||||
|
// given data.
|
||||||
|
func WithTemplate(text string, data TemplateData) Option {
|
||||||
|
return func(key ssh.PublicKey, o *Options) error {
|
||||||
|
terr := new(TemplateError)
|
||||||
|
funcMap := getFuncMap(&terr.Message)
|
||||||
|
|
||||||
|
tmpl, err := template.New("template").Funcs(funcMap).Parse(text)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrapf(err, "error parsing template")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
data.SetPublicKey(key)
|
||||||
|
if err := tmpl.Execute(buf, data); err != nil {
|
||||||
|
if terr.Message != "" {
|
||||||
|
return terr
|
||||||
|
}
|
||||||
|
return errors.Wrapf(err, "error executing template")
|
||||||
|
}
|
||||||
|
o.CertBuffer = buf
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTemplateBase64 is an options that executes the given template base64
|
||||||
|
// string with the given data.
|
||||||
|
func WithTemplateBase64(s string, data TemplateData) Option {
|
||||||
|
return func(key ssh.PublicKey, o *Options) error {
|
||||||
|
b, err := base64.StdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "error decoding template")
|
||||||
|
}
|
||||||
|
fn := WithTemplate(string(b), data)
|
||||||
|
return fn(key, o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTemplateFile is an options that reads the template file and executes it
|
||||||
|
// with the given data.
|
||||||
|
func WithTemplateFile(path string, data TemplateData) Option {
|
||||||
|
return func(key ssh.PublicKey, o *Options) error {
|
||||||
|
filename := config.StepAbs(path)
|
||||||
|
b, err := ioutil.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrapf(err, "error reading %s", path)
|
||||||
|
}
|
||||||
|
fn := WithTemplate(string(b), data)
|
||||||
|
return fn(key, o)
|
||||||
|
}
|
||||||
|
}
|
176
sshutil/options_test.go
Normal file
176
sshutil/options_test.go
Normal file
|
@ -0,0 +1,176 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_getFuncMap_fail(t *testing.T) {
|
||||||
|
var failMesage string
|
||||||
|
fns := getFuncMap(&failMesage)
|
||||||
|
fail := fns["fail"].(func(s string) (string, error))
|
||||||
|
s, err := fail("the fail message")
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("fail() error = %v, wantErr %v", err, errors.New("the fail message"))
|
||||||
|
}
|
||||||
|
if s != "" {
|
||||||
|
t.Errorf("fail() = \"%s\", want \"the fail message\"", s)
|
||||||
|
}
|
||||||
|
if failMesage != "the fail message" {
|
||||||
|
t.Errorf("fail() message = \"%s\", want \"the fail message\"", failMesage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTemplate(t *testing.T) {
|
||||||
|
key := mustGeneratePublicKey(t)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
text string
|
||||||
|
data TemplateData
|
||||||
|
key ssh.PublicKey
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want Options
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"user", args{DefaultCertificate, TemplateData{
|
||||||
|
TypeKey: "user",
|
||||||
|
KeyIDKey: "jane@doe.com",
|
||||||
|
PrincipalsKey: []string{"jane", "jane@doe.com"},
|
||||||
|
ExtensionsKey: DefaultExtensions(UserCert),
|
||||||
|
}, key}, Options{
|
||||||
|
CertBuffer: bytes.NewBufferString(`{
|
||||||
|
"type": "user",
|
||||||
|
"keyId": "jane@doe.com",
|
||||||
|
"principals": ["jane","jane@doe.com"],
|
||||||
|
"extensions": {"permit-X11-forwarding":"","permit-agent-forwarding":"","permit-port-forwarding":"","permit-pty":"","permit-user-rc":""}
|
||||||
|
}`)}, false},
|
||||||
|
{"host", args{DefaultCertificate, TemplateData{
|
||||||
|
TypeKey: "host",
|
||||||
|
KeyIDKey: "foo",
|
||||||
|
PrincipalsKey: []string{"foo.internal"},
|
||||||
|
}, key}, Options{
|
||||||
|
CertBuffer: bytes.NewBufferString(`{
|
||||||
|
"type": "host",
|
||||||
|
"keyId": "foo",
|
||||||
|
"principals": ["foo.internal"],
|
||||||
|
"extensions": null
|
||||||
|
}`)}, false},
|
||||||
|
{"fail", args{`{{ fail "a message" }}`, TemplateData{}, key}, Options{}, true},
|
||||||
|
{"failTemplate", args{`{{ fail "fatal error }}`, TemplateData{}, key}, Options{}, true},
|
||||||
|
{"error", args{`{{ mustHas 3 .Data }}`, TemplateData{
|
||||||
|
"Data": 3,
|
||||||
|
}, key}, Options{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var got Options
|
||||||
|
fn := WithTemplate(tt.args.text, tt.args.data)
|
||||||
|
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("WithTemplate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("WithTemplate() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTemplateBase64(t *testing.T) {
|
||||||
|
key := mustGeneratePublicKey(t)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
s string
|
||||||
|
data TemplateData
|
||||||
|
key ssh.PublicKey
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want Options
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"host", args{base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), TemplateData{
|
||||||
|
TypeKey: "host",
|
||||||
|
KeyIDKey: "foo.internal",
|
||||||
|
PrincipalsKey: []string{"foo.internal", "bar.internal"},
|
||||||
|
ExtensionsKey: map[string]interface{}{"foo": "bar"},
|
||||||
|
}, key}, Options{
|
||||||
|
CertBuffer: bytes.NewBufferString(`{
|
||||||
|
"type": "host",
|
||||||
|
"keyId": "foo.internal",
|
||||||
|
"principals": ["foo.internal","bar.internal"],
|
||||||
|
"extensions": {"foo":"bar"}
|
||||||
|
}`)}, false},
|
||||||
|
{"badBase64", args{"foobar", TemplateData{}, key}, Options{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var got Options
|
||||||
|
fn := WithTemplateBase64(tt.args.s, tt.args.data)
|
||||||
|
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("WithTemplateBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("WithTemplateBase64() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithTemplateFile(t *testing.T) {
|
||||||
|
key := mustGeneratePublicKey(t)
|
||||||
|
|
||||||
|
data := TemplateData{
|
||||||
|
TypeKey: "user",
|
||||||
|
KeyIDKey: "jane@doe.com",
|
||||||
|
PrincipalsKey: []string{"jane", "jane@doe.com"},
|
||||||
|
ExtensionsKey: DefaultExtensions(UserCert),
|
||||||
|
InsecureKey: map[string]interface{}{
|
||||||
|
UserKey: map[string]interface{}{
|
||||||
|
"username": "jane",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
path string
|
||||||
|
data TemplateData
|
||||||
|
key ssh.PublicKey
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want Options
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"github.com", args{"./testdata/github.tpl", data, key}, Options{
|
||||||
|
CertBuffer: bytes.NewBufferString(`{
|
||||||
|
"type": "user",
|
||||||
|
"keyId": "jane@doe.com",
|
||||||
|
"principals": ["jane","jane@doe.com"],
|
||||||
|
"extensions": {"login@github.com":"jane","permit-X11-forwarding":"","permit-agent-forwarding":"","permit-port-forwarding":"","permit-pty":"","permit-user-rc":""}
|
||||||
|
}`),
|
||||||
|
}, false},
|
||||||
|
{"missing", args{"./testdata/missing.tpl", data, key}, Options{}, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var got Options
|
||||||
|
fn := WithTemplateFile(tt.args.path, tt.args.data)
|
||||||
|
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("WithTemplateFile() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("WithTemplateFile() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
132
sshutil/templates.go
Normal file
132
sshutil/templates.go
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeKey = "Type"
|
||||||
|
KeyIDKey = "KeyID"
|
||||||
|
PrincipalsKey = "Principals"
|
||||||
|
ExtensionsKey = "Extensions"
|
||||||
|
TokenKey = "Token"
|
||||||
|
InsecureKey = "Insecure"
|
||||||
|
UserKey = "User"
|
||||||
|
PublicKey = "PublicKey"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TemplateError represents an error in a template produced by the fail
|
||||||
|
// function.
|
||||||
|
type TemplateError struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error implements the error interface and returns the error string when a
|
||||||
|
// template executes the `fail "message"` function.
|
||||||
|
func (e *TemplateError) Error() string {
|
||||||
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// TemplateData is an alias for map[string]interface{}. It represents the data
|
||||||
|
// passed to the templates.
|
||||||
|
type TemplateData map[string]interface{}
|
||||||
|
|
||||||
|
// CreateTemplateData returns a TemplateData with the given certificate type,
|
||||||
|
// key id, principals, and the default extensions.
|
||||||
|
func CreateTemplateData(ct CertType, keyID string, principals []string) TemplateData {
|
||||||
|
return TemplateData{
|
||||||
|
TypeKey: ct.String(),
|
||||||
|
KeyIDKey: keyID,
|
||||||
|
PrincipalsKey: principals,
|
||||||
|
ExtensionsKey: DefaultExtensions(ct),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultExtensions returns the default extensions set in an SSH certificate.
|
||||||
|
func DefaultExtensions(ct CertType) map[string]interface{} {
|
||||||
|
switch ct {
|
||||||
|
case UserCert:
|
||||||
|
return map[string]interface{}{
|
||||||
|
"permit-X11-forwarding": "",
|
||||||
|
"permit-agent-forwarding": "",
|
||||||
|
"permit-port-forwarding": "",
|
||||||
|
"permit-pty": "",
|
||||||
|
"permit-user-rc": "",
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTemplateData creates a new map for templates data.
|
||||||
|
func NewTemplateData() TemplateData {
|
||||||
|
return TemplateData{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddExtension adds one extension to the templates data.
|
||||||
|
func (t TemplateData) AddExtension(key, value string) {
|
||||||
|
if m, ok := t[ExtensionsKey].(map[string]interface{}); ok {
|
||||||
|
m[key] = value
|
||||||
|
} else {
|
||||||
|
t[ExtensionsKey] = map[string]interface{}{
|
||||||
|
key: value,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets a key-value pair in the template data.
|
||||||
|
func (t TemplateData) Set(key string, v interface{}) {
|
||||||
|
t[key] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetInsecure sets a key-value pair in the insecure template data.
|
||||||
|
func (t TemplateData) SetInsecure(key string, v interface{}) {
|
||||||
|
if m, ok := t[InsecureKey].(TemplateData); ok {
|
||||||
|
m[key] = v
|
||||||
|
} else {
|
||||||
|
t[InsecureKey] = TemplateData{key: v}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetType sets the certificate type in the template data.
|
||||||
|
func (t TemplateData) SetType(typ CertType) {
|
||||||
|
t.Set(TypeKey, typ.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetType sets the certificate key id in the template data.
|
||||||
|
func (t TemplateData) SetKeyID(id string) {
|
||||||
|
t.Set(KeyIDKey, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPrincipals sets the certificate principals in the template data.
|
||||||
|
func (t TemplateData) SetPrincipals(p []string) {
|
||||||
|
t.Set(PrincipalsKey, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExtensions sets the certificate extensions in the template data.
|
||||||
|
func (t TemplateData) SetExtensions(e map[string]interface{}) {
|
||||||
|
t.Set(ExtensionsKey, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetToken sets the given token in the template data.
|
||||||
|
func (t TemplateData) SetToken(v interface{}) {
|
||||||
|
t.Set(TokenKey, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserData sets the given user provided object in the insecure template
|
||||||
|
// data.
|
||||||
|
func (t TemplateData) SetUserData(v interface{}) {
|
||||||
|
t.SetInsecure(UserKey, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserData sets the given user provided object in the insecure template
|
||||||
|
// data.
|
||||||
|
func (t TemplateData) SetPublicKey(v ssh.PublicKey) {
|
||||||
|
t.Set(PublicKey, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCertificate is the default template for an SSH certificate.
|
||||||
|
const DefaultCertificate = `{
|
||||||
|
"type": "{{ .Type }}",
|
||||||
|
"keyId": "{{ .KeyID }}",
|
||||||
|
"principals": {{ toJson .Principals }},
|
||||||
|
"extensions": {{ toJson .Extensions }}
|
||||||
|
}`
|
406
sshutil/templates_test.go
Normal file
406
sshutil/templates_test.go
Normal file
|
@ -0,0 +1,406 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTemplateError_Error(t *testing.T) {
|
||||||
|
type fields struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fields fields
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"ok", fields{"message"}, "message"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
e := &TemplateError{
|
||||||
|
Message: tt.fields.Message,
|
||||||
|
}
|
||||||
|
if got := e.Error(); got != tt.want {
|
||||||
|
t.Errorf("TemplateError.Error() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateTemplateData(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
ct CertType
|
||||||
|
keyID string
|
||||||
|
principals []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"user", args{UserCert, "john@doe.com", []string{"john", "john@doe.com"}}, TemplateData{
|
||||||
|
TypeKey: "user",
|
||||||
|
KeyIDKey: "john@doe.com",
|
||||||
|
PrincipalsKey: []string{"john", "john@doe.com"},
|
||||||
|
ExtensionsKey: map[string]interface{}{
|
||||||
|
"permit-X11-forwarding": "",
|
||||||
|
"permit-agent-forwarding": "",
|
||||||
|
"permit-port-forwarding": "",
|
||||||
|
"permit-pty": "",
|
||||||
|
"permit-user-rc": "",
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
{"host", args{HostCert, "foo", []string{"foo.internal"}}, TemplateData{
|
||||||
|
TypeKey: "host",
|
||||||
|
KeyIDKey: "foo",
|
||||||
|
PrincipalsKey: []string{"foo.internal"},
|
||||||
|
ExtensionsKey: map[string]interface{}(nil),
|
||||||
|
}},
|
||||||
|
{"other", args{100, "foo", []string{"foo.internal"}}, TemplateData{
|
||||||
|
TypeKey: "",
|
||||||
|
KeyIDKey: "foo",
|
||||||
|
PrincipalsKey: []string{"foo.internal"},
|
||||||
|
ExtensionsKey: map[string]interface{}(nil),
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := CreateTemplateData(tt.args.ct, tt.args.keyID, tt.args.principals); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("CreateTemplateData() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultExtensions(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
ct CertType
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want map[string]interface{}
|
||||||
|
}{
|
||||||
|
{"user", args{UserCert}, map[string]interface{}{
|
||||||
|
"permit-X11-forwarding": "",
|
||||||
|
"permit-agent-forwarding": "",
|
||||||
|
"permit-port-forwarding": "",
|
||||||
|
"permit-pty": "",
|
||||||
|
"permit-user-rc": "",
|
||||||
|
}},
|
||||||
|
{"host", args{HostCert}, nil},
|
||||||
|
{"other", args{100}, nil},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := DefaultExtensions(tt.args.ct); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("DefaultExtensions() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewTemplateData(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"ok", TemplateData{}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := NewTemplateData(); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("NewTemplateData() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_AddExtension(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
key string
|
||||||
|
value string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
t TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"empty", TemplateData{}, args{"key", "value"}, TemplateData{
|
||||||
|
ExtensionsKey: map[string]interface{}{"key": "value"},
|
||||||
|
}},
|
||||||
|
{"overwrite", TemplateData{
|
||||||
|
ExtensionsKey: map[string]interface{}{"key": "value"},
|
||||||
|
}, args{"key", "value"}, TemplateData{
|
||||||
|
ExtensionsKey: map[string]interface{}{
|
||||||
|
"key": "value",
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
{"add", TemplateData{
|
||||||
|
ExtensionsKey: map[string]interface{}{"foo": "bar"},
|
||||||
|
}, args{"key", "value"}, TemplateData{
|
||||||
|
ExtensionsKey: map[string]interface{}{
|
||||||
|
"key": "value",
|
||||||
|
"foo": "bar",
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.t.AddExtension(tt.args.key, tt.args.value)
|
||||||
|
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||||
|
t.Errorf("AddExtension() = %v, want %v", tt.t, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_Set(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
key string
|
||||||
|
v interface{}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
t TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"ok", TemplateData{}, args{"foo", "bar"}, TemplateData{
|
||||||
|
"foo": "bar",
|
||||||
|
}},
|
||||||
|
{"overwrite", TemplateData{}, args{"foo", "bar"}, TemplateData{
|
||||||
|
"foo": "bar",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.t.Set(tt.args.key, tt.args.v)
|
||||||
|
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||||
|
t.Errorf("Set() = %v, want %v", tt.t, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_SetInsecure(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
key string
|
||||||
|
v interface{}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
td TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"empty", TemplateData{}, args{"foo", "bar"}, TemplateData{InsecureKey: TemplateData{"foo": "bar"}}},
|
||||||
|
{"overwrite", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"foo", "zar"}, TemplateData{InsecureKey: TemplateData{"foo": "zar"}}},
|
||||||
|
{"add", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"bar", "foo"}, TemplateData{InsecureKey: TemplateData{"foo": "bar", "bar": "foo"}}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.td.SetInsecure(tt.args.key, tt.args.v)
|
||||||
|
if !reflect.DeepEqual(tt.td, tt.want) {
|
||||||
|
t.Errorf("TemplateData.SetInsecure() = %v, want %v", tt.td, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_SetType(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
typ CertType
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
t TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"user", TemplateData{}, args{UserCert}, TemplateData{
|
||||||
|
TypeKey: "user",
|
||||||
|
}},
|
||||||
|
{"host", TemplateData{}, args{HostCert}, TemplateData{
|
||||||
|
TypeKey: "host",
|
||||||
|
}},
|
||||||
|
{"overwrite", TemplateData{
|
||||||
|
TypeKey: "host",
|
||||||
|
}, args{UserCert}, TemplateData{
|
||||||
|
TypeKey: "user",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.t.SetType(tt.args.typ)
|
||||||
|
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||||
|
t.Errorf("SetType() = %v, want %v", tt.t, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_SetKeyID(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
t TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"ok", TemplateData{}, args{"key-id"}, TemplateData{
|
||||||
|
KeyIDKey: "key-id",
|
||||||
|
}},
|
||||||
|
{"overwrite", TemplateData{}, args{"key-id-2"}, TemplateData{
|
||||||
|
KeyIDKey: "key-id-2",
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.t.SetKeyID(tt.args.id)
|
||||||
|
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||||
|
t.Errorf("SetKeyID() = %v, want %v", tt.t, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_SetPrincipals(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
p []string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
t TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"ok", TemplateData{}, args{[]string{"jane"}}, TemplateData{
|
||||||
|
PrincipalsKey: []string{"jane"},
|
||||||
|
}},
|
||||||
|
{"overwrite", TemplateData{}, args{[]string{"john", "john@doe.com"}}, TemplateData{
|
||||||
|
PrincipalsKey: []string{"john", "john@doe.com"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.t.SetPrincipals(tt.args.p)
|
||||||
|
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||||
|
t.Errorf("SetPrincipals() = %v, want %v", tt.t, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_SetExtensions(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
e map[string]interface{}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
t TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"ok", TemplateData{}, args{map[string]interface{}{"foo": "bar"}}, TemplateData{
|
||||||
|
ExtensionsKey: map[string]interface{}{"foo": "bar"},
|
||||||
|
}},
|
||||||
|
{"overwrite", TemplateData{
|
||||||
|
ExtensionsKey: map[string]interface{}{"foo": "bar"},
|
||||||
|
}, args{map[string]interface{}{"key": "value"}}, TemplateData{
|
||||||
|
ExtensionsKey: map[string]interface{}{"key": "value"},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.t.SetExtensions(tt.args.e)
|
||||||
|
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||||
|
t.Errorf("SetExtensions() = %v, want %v", tt.t, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_SetToken(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
v interface{}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
td TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"ok", TemplateData{}, args{"token"}, TemplateData{TokenKey: "token"}},
|
||||||
|
{"overwrite", TemplateData{TokenKey: "foo"}, args{"token"}, TemplateData{TokenKey: "token"}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.td.SetToken(tt.args.v)
|
||||||
|
if !reflect.DeepEqual(tt.td, tt.want) {
|
||||||
|
t.Errorf("TemplateData.SetToken() = %v, want %v", tt.td, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_SetUserData(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
v interface{}
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
td TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"ok", TemplateData{}, args{"userData"}, TemplateData{InsecureKey: TemplateData{UserKey: "userData"}}},
|
||||||
|
{"overwrite", TemplateData{InsecureKey: TemplateData{UserKey: "foo"}}, args{"userData"}, TemplateData{InsecureKey: TemplateData{UserKey: "userData"}}},
|
||||||
|
{"existing", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"userData"}, TemplateData{InsecureKey: TemplateData{"foo": "bar", UserKey: "userData"}}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.td.SetUserData(tt.args.v)
|
||||||
|
if !reflect.DeepEqual(tt.td, tt.want) {
|
||||||
|
t.Errorf("TemplateData.SetUserData() = %v, want %v", tt.td, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateData_SetPublicKey(t *testing.T) {
|
||||||
|
k1 := mustGeneratePublicKey(t)
|
||||||
|
k2 := mustGeneratePublicKey(t)
|
||||||
|
type args struct {
|
||||||
|
v ssh.PublicKey
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
t TemplateData
|
||||||
|
args args
|
||||||
|
want TemplateData
|
||||||
|
}{
|
||||||
|
{"ok", TemplateData{}, args{k1}, TemplateData{
|
||||||
|
PublicKey: k1,
|
||||||
|
}},
|
||||||
|
{"overwrite", TemplateData{
|
||||||
|
PublicKey: k1,
|
||||||
|
}, args{k2}, TemplateData{
|
||||||
|
PublicKey: k2,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.t.SetPublicKey(tt.args.v)
|
||||||
|
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||||
|
t.Errorf("TemplateData.SetPublicKey() = %v, want %v", tt.t, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
10
sshutil/testdata/github.tpl
vendored
Normal file
10
sshutil/testdata/github.tpl
vendored
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
{
|
||||||
|
"type": "{{ .Type }}",
|
||||||
|
"keyId": "{{ .KeyID }}",
|
||||||
|
"principals": {{ toJson .Principals }},
|
||||||
|
{{- if .Insecure.User.username }}
|
||||||
|
"extensions": {{ set .Extensions "login@github.com" .Insecure.User.username | toJson }}
|
||||||
|
{{- else }}
|
||||||
|
"extensions": {{ toJson .Extensions }}
|
||||||
|
{{- end }}
|
||||||
|
}
|
|
@ -1,5 +1,13 @@
|
||||||
package sshutil
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
// Hosts are tagged with k,v pairs. These tags are how a user is ultimately
|
// Hosts are tagged with k,v pairs. These tags are how a user is ultimately
|
||||||
// associated with a host.
|
// associated with a host.
|
||||||
type HostTag struct {
|
type HostTag struct {
|
||||||
|
@ -14,3 +22,60 @@ type Host struct {
|
||||||
HostTags []HostTag `json:"host_tags"`
|
HostTags []HostTag `json:"host_tags"`
|
||||||
Hostname string `json:"hostname"`
|
Hostname string `json:"hostname"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CertType defines the certificate type, it can be a user or a host
|
||||||
|
// certificate.
|
||||||
|
type CertType uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UserCert defines a user certificate.
|
||||||
|
UserCert CertType = ssh.UserCert
|
||||||
|
|
||||||
|
// HostCert defines a host certificate.
|
||||||
|
HostCert CertType = ssh.HostCert
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
userString = "user"
|
||||||
|
hostString = "host"
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns "user" for user certificates and "host" for host certificates.
|
||||||
|
// It will return the empty string for any other value.
|
||||||
|
func (c CertType) String() string {
|
||||||
|
switch c {
|
||||||
|
case UserCert:
|
||||||
|
return userString
|
||||||
|
case HostCert:
|
||||||
|
return hostString
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the json.Marshaler interface for CertType. UserCert
|
||||||
|
// will be marshaled as the string "user" and HostCert as "host".
|
||||||
|
func (c CertType) MarshalJSON() ([]byte, error) {
|
||||||
|
if s := c.String(); s != "" {
|
||||||
|
return []byte(`"` + s + `"`), nil
|
||||||
|
}
|
||||||
|
return nil, errors.Errorf("unknown certificate type %d", c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the json.Unmarshaler interface for CertType.
|
||||||
|
func (c *CertType) UnmarshalJSON(data []byte) error {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return errors.Wrap(err, "error unmarshaling certificate type")
|
||||||
|
}
|
||||||
|
switch strings.ToLower(s) {
|
||||||
|
case userString:
|
||||||
|
*c = UserCert
|
||||||
|
return nil
|
||||||
|
case hostString:
|
||||||
|
*c = HostCert
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.Errorf("error unmarshaling '%s' as a certificate type", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
82
sshutil/types_test.go
Normal file
82
sshutil/types_test.go
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCertType_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
c CertType
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"user", UserCert, "user"},
|
||||||
|
{"host", HostCert, "host"},
|
||||||
|
{"empty", 100, ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.c.String(); got != tt.want {
|
||||||
|
t.Errorf("CertType.String() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertType_MarshalJSON(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
c CertType
|
||||||
|
want []byte
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"user", UserCert, []byte(`"user"`), false},
|
||||||
|
{"host", HostCert, []byte(`"host"`), false},
|
||||||
|
{"error", 100, nil, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := tt.c.MarshalJSON()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("CertType.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("CertType.MarshalJSON() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertType_UnmarshalJSON(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want CertType
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"user", args{[]byte(`"user"`)}, UserCert, false},
|
||||||
|
{"USER", args{[]byte(`"USER"`)}, UserCert, false},
|
||||||
|
{"host", args{[]byte(`"host"`)}, HostCert, false},
|
||||||
|
{"HosT", args{[]byte(`"HosT"`)}, HostCert, false},
|
||||||
|
{" user ", args{[]byte(`" user "`)}, 0, true},
|
||||||
|
{"number", args{[]byte(`1`)}, 0, true},
|
||||||
|
{"object", args{[]byte(`{}`)}, 0, true},
|
||||||
|
{"badJSON", args{[]byte(`"user`)}, 0, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var ct CertType
|
||||||
|
if err := ct.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("CertType.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(ct, tt.want) {
|
||||||
|
t.Errorf("CertType.UnmarshalJSON() = %v, want %v", ct, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue