Add package to generate ssh certificate for templates.
This commit is contained in:
parent
3e80f41c19
commit
af3eeb870e
9 changed files with 1413 additions and 0 deletions
104
sshutil/certificate.go
Normal file
104
sshutil/certificate.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package sshutil
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/crypto/randutil"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Certificate is the json representation of ssh.Certificate.
|
||||
type Certificate struct {
|
||||
Nonce []byte `json:"nonce"`
|
||||
Key ssh.PublicKey `json:"-"`
|
||||
Serial uint64 `json:"serial"`
|
||||
Type CertType `json:"type"`
|
||||
KeyID string `json:"keyId"`
|
||||
Principals []string `json:"principals"`
|
||||
ValidAfter uint64 `json:"-"`
|
||||
ValidBefore uint64 `json:"-"`
|
||||
CriticalOptions map[string]string `json:"criticalOptions"`
|
||||
Extensions map[string]string `json:"extensions"`
|
||||
Reserved []byte `json:"reserved"`
|
||||
SignatureKey ssh.PublicKey `json:"-"`
|
||||
Signature *ssh.Signature `json:"-"`
|
||||
}
|
||||
|
||||
// NewCertificate creates a new certificate with the given key after parsing a
|
||||
// template given in the options.
|
||||
func NewCertificate(key ssh.PublicKey, opts ...Option) (*Certificate, error) {
|
||||
o, err := new(Options).apply(key, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if o.CertBuffer == nil {
|
||||
return nil, errors.New("certificate template cannot be empty")
|
||||
}
|
||||
|
||||
// With templates
|
||||
var cert Certificate
|
||||
if err := json.NewDecoder(o.CertBuffer).Decode(&cert); err != nil {
|
||||
return nil, errors.Wrap(err, "error unmarshaling certificate")
|
||||
}
|
||||
|
||||
// Complete with public key
|
||||
cert.Key = key
|
||||
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
func (c *Certificate) GetCertificate() *ssh.Certificate {
|
||||
return &ssh.Certificate{
|
||||
Nonce: c.Nonce,
|
||||
Key: c.Key,
|
||||
Serial: c.Serial,
|
||||
CertType: uint32(c.Type),
|
||||
KeyId: c.KeyID,
|
||||
ValidPrincipals: c.Principals,
|
||||
ValidAfter: c.ValidAfter,
|
||||
ValidBefore: c.ValidBefore,
|
||||
Permissions: ssh.Permissions{
|
||||
CriticalOptions: c.CriticalOptions,
|
||||
Extensions: c.Extensions,
|
||||
},
|
||||
Reserved: c.Reserved,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCertificate signs the given certificate with the given signer. If the
|
||||
// certificate does not have a nonce or a serial, it will create random ones.
|
||||
func CreateCertificate(cert *ssh.Certificate, signer ssh.Signer) (*ssh.Certificate, error) {
|
||||
if len(cert.Nonce) == 0 {
|
||||
nonce, err := randutil.ASCII(32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cert.Nonce = []byte(nonce)
|
||||
}
|
||||
|
||||
if cert.Serial == 0 {
|
||||
if err := binary.Read(rand.Reader, binary.BigEndian, &cert.Serial); err != nil {
|
||||
return nil, errors.Wrap(err, "error reading random number")
|
||||
}
|
||||
}
|
||||
|
||||
// Set signer public key.
|
||||
cert.SignatureKey = signer.PublicKey()
|
||||
|
||||
// Get bytes for signing trailing the signature length.
|
||||
data := cert.Marshal()
|
||||
data = data[:len(data)-4]
|
||||
|
||||
// Sign the certificate.
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error signing certificate")
|
||||
}
|
||||
cert.Signature = sig
|
||||
|
||||
return cert, nil
|
||||
}
|
347
sshutil/certificate_test.go
Normal file
347
sshutil/certificate_test.go
Normal file
|
@ -0,0 +1,347 @@
|
|||
package sshutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func mustGenerateKey(t *testing.T) (ssh.PublicKey, ssh.Signer) {
|
||||
t.Helper()
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key, err := ssh.NewPublicKey(pub)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
signer, err := ssh.NewSignerFromKey(priv)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return key, signer
|
||||
}
|
||||
|
||||
func mustGeneratePublicKey(t *testing.T) ssh.PublicKey {
|
||||
t.Helper()
|
||||
key, _ := mustGenerateKey(t)
|
||||
return key
|
||||
}
|
||||
|
||||
func TestNewCertificate(t *testing.T) {
|
||||
key := mustGeneratePublicKey(t)
|
||||
|
||||
type args struct {
|
||||
key ssh.PublicKey
|
||||
opts []Option
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Certificate
|
||||
wantErr bool
|
||||
}{
|
||||
{"user", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(UserCert, "jane@doe.com", []string{"jane"}))}}, &Certificate{
|
||||
Nonce: nil,
|
||||
Key: key,
|
||||
Serial: 0,
|
||||
Type: UserCert,
|
||||
KeyID: "jane@doe.com",
|
||||
Principals: []string{"jane"},
|
||||
ValidAfter: 0,
|
||||
ValidBefore: 0,
|
||||
CriticalOptions: nil,
|
||||
Extensions: map[string]string{
|
||||
"permit-X11-forwarding": "",
|
||||
"permit-agent-forwarding": "",
|
||||
"permit-port-forwarding": "",
|
||||
"permit-pty": "",
|
||||
"permit-user-rc": "",
|
||||
},
|
||||
Reserved: nil,
|
||||
SignatureKey: nil,
|
||||
Signature: nil,
|
||||
}, false},
|
||||
{"host", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(HostCert, "foobar", []string{"foo.internal", "bar.internal"}))}}, &Certificate{
|
||||
Nonce: nil,
|
||||
Key: key,
|
||||
Serial: 0,
|
||||
Type: HostCert,
|
||||
KeyID: "foobar",
|
||||
Principals: []string{"foo.internal", "bar.internal"},
|
||||
ValidAfter: 0,
|
||||
ValidBefore: 0,
|
||||
CriticalOptions: nil,
|
||||
Extensions: nil,
|
||||
Reserved: nil,
|
||||
SignatureKey: nil,
|
||||
Signature: nil,
|
||||
}, false},
|
||||
{"file", args{key, []Option{WithTemplateFile("./testdata/github.tpl", TemplateData{
|
||||
TypeKey: UserCert,
|
||||
KeyIDKey: "john@doe.com",
|
||||
PrincipalsKey: []string{"john", "john@doe.com"},
|
||||
ExtensionsKey: DefaultExtensions(UserCert),
|
||||
InsecureKey: map[string]interface{}{
|
||||
"User": map[string]interface{}{"username": "john"},
|
||||
},
|
||||
})}}, &Certificate{
|
||||
Nonce: nil,
|
||||
Key: key,
|
||||
Serial: 0,
|
||||
Type: UserCert,
|
||||
KeyID: "john@doe.com",
|
||||
Principals: []string{"john", "john@doe.com"},
|
||||
ValidAfter: 0,
|
||||
ValidBefore: 0,
|
||||
CriticalOptions: nil,
|
||||
Extensions: map[string]string{
|
||||
"permit-X11-forwarding": "",
|
||||
"permit-agent-forwarding": "",
|
||||
"permit-port-forwarding": "",
|
||||
"permit-pty": "",
|
||||
"permit-user-rc": "",
|
||||
"login@github.com": "john",
|
||||
},
|
||||
Reserved: nil,
|
||||
SignatureKey: nil,
|
||||
Signature: nil,
|
||||
}, false},
|
||||
{"base64", args{key, []Option{WithTemplateBase64(base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), CreateTemplateData(HostCert, "foo.internal", nil))}}, &Certificate{
|
||||
Nonce: nil,
|
||||
Key: key,
|
||||
Serial: 0,
|
||||
Type: HostCert,
|
||||
KeyID: "foo.internal",
|
||||
Principals: nil,
|
||||
ValidAfter: 0,
|
||||
ValidBefore: 0,
|
||||
CriticalOptions: nil,
|
||||
Extensions: nil,
|
||||
Reserved: nil,
|
||||
SignatureKey: nil,
|
||||
Signature: nil,
|
||||
}, false},
|
||||
{"failNilOptions", args{key, nil}, nil, true},
|
||||
{"failEmptyOptions", args{key, nil}, nil, true},
|
||||
{"badBase64Template", args{key, []Option{WithTemplateBase64("foobar", TemplateData{})}}, nil, true},
|
||||
{"badFileTemplate", args{key, []Option{WithTemplateFile("./testdata/missing.tpl", TemplateData{})}}, nil, true},
|
||||
{"badJsonTemplate", args{key, []Option{WithTemplate(`{"type":{{ .Type }}}`, TemplateData{})}}, nil, true},
|
||||
{"failTemplate", args{key, []Option{WithTemplate(`{{ fail "an error" }}`, TemplateData{})}}, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewCertificate(tt.args.key, tt.args.opts...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewCertificate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificate_GetCertificate(t *testing.T) {
|
||||
key := mustGeneratePublicKey(t)
|
||||
|
||||
type fields struct {
|
||||
Nonce []byte
|
||||
Key ssh.PublicKey
|
||||
Serial uint64
|
||||
Type CertType
|
||||
KeyID string
|
||||
Principals []string
|
||||
ValidAfter uint64
|
||||
ValidBefore uint64
|
||||
CriticalOptions map[string]string
|
||||
Extensions map[string]string
|
||||
Reserved []byte
|
||||
SignatureKey ssh.PublicKey
|
||||
Signature *ssh.Signature
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want *ssh.Certificate
|
||||
}{
|
||||
{"user", fields{
|
||||
Nonce: []byte("0123456789"),
|
||||
Key: key,
|
||||
Serial: 123,
|
||||
Type: UserCert,
|
||||
KeyID: "key-id",
|
||||
Principals: []string{"john"},
|
||||
ValidAfter: 1111,
|
||||
ValidBefore: 2222,
|
||||
CriticalOptions: map[string]string{"foo": "bar"},
|
||||
Extensions: map[string]string{"login@github.com": "john"},
|
||||
Reserved: []byte("reserved"),
|
||||
SignatureKey: key,
|
||||
Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")},
|
||||
}, &ssh.Certificate{
|
||||
Nonce: []byte("0123456789"),
|
||||
Key: key,
|
||||
Serial: 123,
|
||||
CertType: ssh.UserCert,
|
||||
KeyId: "key-id",
|
||||
ValidPrincipals: []string{"john"},
|
||||
ValidAfter: 1111,
|
||||
ValidBefore: 2222,
|
||||
Permissions: ssh.Permissions{
|
||||
CriticalOptions: map[string]string{"foo": "bar"},
|
||||
Extensions: map[string]string{"login@github.com": "john"},
|
||||
},
|
||||
Reserved: []byte("reserved"),
|
||||
}},
|
||||
{"host", fields{
|
||||
Nonce: []byte("0123456789"),
|
||||
Key: key,
|
||||
Serial: 123,
|
||||
Type: HostCert,
|
||||
KeyID: "key-id",
|
||||
Principals: []string{"foo.internal", "bar.internal"},
|
||||
ValidAfter: 1111,
|
||||
ValidBefore: 2222,
|
||||
CriticalOptions: map[string]string{"foo": "bar"},
|
||||
Extensions: nil,
|
||||
Reserved: []byte("reserved"),
|
||||
SignatureKey: key,
|
||||
Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")},
|
||||
}, &ssh.Certificate{
|
||||
Nonce: []byte("0123456789"),
|
||||
Key: key,
|
||||
Serial: 123,
|
||||
CertType: ssh.HostCert,
|
||||
KeyId: "key-id",
|
||||
ValidPrincipals: []string{"foo.internal", "bar.internal"},
|
||||
ValidAfter: 1111,
|
||||
ValidBefore: 2222,
|
||||
Permissions: ssh.Permissions{
|
||||
CriticalOptions: map[string]string{"foo": "bar"},
|
||||
Extensions: nil,
|
||||
},
|
||||
Reserved: []byte("reserved"),
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Certificate{
|
||||
Nonce: tt.fields.Nonce,
|
||||
Key: tt.fields.Key,
|
||||
Serial: tt.fields.Serial,
|
||||
Type: tt.fields.Type,
|
||||
KeyID: tt.fields.KeyID,
|
||||
Principals: tt.fields.Principals,
|
||||
ValidAfter: tt.fields.ValidAfter,
|
||||
ValidBefore: tt.fields.ValidBefore,
|
||||
CriticalOptions: tt.fields.CriticalOptions,
|
||||
Extensions: tt.fields.Extensions,
|
||||
Reserved: tt.fields.Reserved,
|
||||
SignatureKey: tt.fields.SignatureKey,
|
||||
Signature: tt.fields.Signature,
|
||||
}
|
||||
if got := c.GetCertificate(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Certificate.GetCertificate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCertificate(t *testing.T) {
|
||||
key, signer := mustGenerateKey(t)
|
||||
type args struct {
|
||||
cert *ssh.Certificate
|
||||
signer ssh.Signer
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{"ok", args{&ssh.Certificate{
|
||||
Nonce: []byte("0123456789"),
|
||||
Key: key,
|
||||
Serial: 123,
|
||||
CertType: ssh.HostCert,
|
||||
KeyId: "foo",
|
||||
ValidPrincipals: []string{"foo.internal"},
|
||||
ValidAfter: 1111,
|
||||
ValidBefore: 2222,
|
||||
Permissions: ssh.Permissions{},
|
||||
Reserved: []byte("reserved"),
|
||||
}, signer}, false},
|
||||
{"emptyNonce", args{&ssh.Certificate{
|
||||
Key: key,
|
||||
Serial: 123,
|
||||
CertType: ssh.UserCert,
|
||||
KeyId: "jane@doe.com",
|
||||
ValidPrincipals: []string{"jane"},
|
||||
ValidAfter: 1111,
|
||||
ValidBefore: 2222,
|
||||
Permissions: ssh.Permissions{},
|
||||
Reserved: []byte("reserved"),
|
||||
}, signer}, false},
|
||||
{"emptySerial", args{&ssh.Certificate{
|
||||
Nonce: []byte("0123456789"),
|
||||
Key: key,
|
||||
CertType: ssh.UserCert,
|
||||
KeyId: "jane@doe.com",
|
||||
ValidPrincipals: []string{"jane"},
|
||||
ValidAfter: 1111,
|
||||
ValidBefore: 2222,
|
||||
Permissions: ssh.Permissions{},
|
||||
Reserved: []byte("reserved"),
|
||||
}, signer}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CreateCertificate(tt.args.cert, tt.args.signer)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CreateCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != nil {
|
||||
switch {
|
||||
case len(got.Nonce) == 0:
|
||||
t.Errorf("CreateCertificate() nonce should not be empty")
|
||||
case got.Serial == 0:
|
||||
t.Errorf("CreateCertificate() serial should not be 0")
|
||||
case got.Signature == nil:
|
||||
t.Errorf("CreateCertificate() signature should not be nil")
|
||||
case !bytes.Equal(got.SignatureKey.Marshal(), tt.args.signer.PublicKey().Marshal()):
|
||||
t.Errorf("CreateCertificate() signature key is not the expected one")
|
||||
}
|
||||
|
||||
signature := got.Signature
|
||||
got.Signature = nil
|
||||
|
||||
data := got.Marshal()
|
||||
data = data[:len(data)-4]
|
||||
|
||||
sig, err := signer.Sign(rand.Reader, data)
|
||||
if err != nil {
|
||||
t.Errorf("signer.Sign() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify signature
|
||||
got.Signature = signature
|
||||
if err := signer.PublicKey().Verify(data, got.Signature); err != nil {
|
||||
t.Errorf("CreateCertificate() signature verify error = %v", err)
|
||||
}
|
||||
// Verify data with public key in cert
|
||||
if err := got.Verify(data, sig); err != nil {
|
||||
t.Errorf("CreateCertificate() certificate verify error = %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
91
sshutil/options.go
Normal file
91
sshutil/options.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package sshutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io/ioutil"
|
||||
"text/template"
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/smallstep/cli/config"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func getFuncMap(failMessage *string) template.FuncMap {
|
||||
m := sprig.TxtFuncMap()
|
||||
m["fail"] = func(msg string) (string, error) {
|
||||
*failMessage = msg
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Options are the options that can be passed to NewCertificate.
|
||||
type Options struct {
|
||||
CertBuffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func (o *Options) apply(key ssh.PublicKey, opts []Option) (*Options, error) {
|
||||
for _, fn := range opts {
|
||||
if err := fn(key, o); err != nil {
|
||||
return o, err
|
||||
}
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// Option is the type used as a variadic argument in NewCertificate.
|
||||
type Option func(key ssh.PublicKey, o *Options) error
|
||||
|
||||
// WithTemplate is an options that executes the given template text with the
|
||||
// given data.
|
||||
func WithTemplate(text string, data TemplateData) Option {
|
||||
return func(key ssh.PublicKey, o *Options) error {
|
||||
terr := new(TemplateError)
|
||||
funcMap := getFuncMap(&terr.Message)
|
||||
|
||||
tmpl, err := template.New("template").Funcs(funcMap).Parse(text)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error parsing template")
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
data.SetPublicKey(key)
|
||||
if err := tmpl.Execute(buf, data); err != nil {
|
||||
if terr.Message != "" {
|
||||
return terr
|
||||
}
|
||||
return errors.Wrapf(err, "error executing template")
|
||||
}
|
||||
o.CertBuffer = buf
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithTemplateBase64 is an options that executes the given template base64
|
||||
// string with the given data.
|
||||
func WithTemplateBase64(s string, data TemplateData) Option {
|
||||
return func(key ssh.PublicKey, o *Options) error {
|
||||
b, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "error decoding template")
|
||||
}
|
||||
fn := WithTemplate(string(b), data)
|
||||
return fn(key, o)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTemplateFile is an options that reads the template file and executes it
|
||||
// with the given data.
|
||||
func WithTemplateFile(path string, data TemplateData) Option {
|
||||
return func(key ssh.PublicKey, o *Options) error {
|
||||
filename := config.StepAbs(path)
|
||||
b, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "error reading %s", path)
|
||||
}
|
||||
fn := WithTemplate(string(b), data)
|
||||
return fn(key, o)
|
||||
}
|
||||
}
|
176
sshutil/options_test.go
Normal file
176
sshutil/options_test.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
package sshutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func Test_getFuncMap_fail(t *testing.T) {
|
||||
var failMesage string
|
||||
fns := getFuncMap(&failMesage)
|
||||
fail := fns["fail"].(func(s string) (string, error))
|
||||
s, err := fail("the fail message")
|
||||
if err == nil {
|
||||
t.Errorf("fail() error = %v, wantErr %v", err, errors.New("the fail message"))
|
||||
}
|
||||
if s != "" {
|
||||
t.Errorf("fail() = \"%s\", want \"the fail message\"", s)
|
||||
}
|
||||
if failMesage != "the fail message" {
|
||||
t.Errorf("fail() message = \"%s\", want \"the fail message\"", failMesage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTemplate(t *testing.T) {
|
||||
key := mustGeneratePublicKey(t)
|
||||
|
||||
type args struct {
|
||||
text string
|
||||
data TemplateData
|
||||
key ssh.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"user", args{DefaultCertificate, TemplateData{
|
||||
TypeKey: "user",
|
||||
KeyIDKey: "jane@doe.com",
|
||||
PrincipalsKey: []string{"jane", "jane@doe.com"},
|
||||
ExtensionsKey: DefaultExtensions(UserCert),
|
||||
}, key}, Options{
|
||||
CertBuffer: bytes.NewBufferString(`{
|
||||
"type": "user",
|
||||
"keyId": "jane@doe.com",
|
||||
"principals": ["jane","jane@doe.com"],
|
||||
"extensions": {"permit-X11-forwarding":"","permit-agent-forwarding":"","permit-port-forwarding":"","permit-pty":"","permit-user-rc":""}
|
||||
}`)}, false},
|
||||
{"host", args{DefaultCertificate, TemplateData{
|
||||
TypeKey: "host",
|
||||
KeyIDKey: "foo",
|
||||
PrincipalsKey: []string{"foo.internal"},
|
||||
}, key}, Options{
|
||||
CertBuffer: bytes.NewBufferString(`{
|
||||
"type": "host",
|
||||
"keyId": "foo",
|
||||
"principals": ["foo.internal"],
|
||||
"extensions": null
|
||||
}`)}, false},
|
||||
{"fail", args{`{{ fail "a message" }}`, TemplateData{}, key}, Options{}, true},
|
||||
{"failTemplate", args{`{{ fail "fatal error }}`, TemplateData{}, key}, Options{}, true},
|
||||
{"error", args{`{{ mustHas 3 .Data }}`, TemplateData{
|
||||
"Data": 3,
|
||||
}, key}, Options{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got Options
|
||||
fn := WithTemplate(tt.args.text, tt.args.data)
|
||||
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
|
||||
t.Errorf("WithTemplate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("WithTemplate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTemplateBase64(t *testing.T) {
|
||||
key := mustGeneratePublicKey(t)
|
||||
|
||||
type args struct {
|
||||
s string
|
||||
data TemplateData
|
||||
key ssh.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"host", args{base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), TemplateData{
|
||||
TypeKey: "host",
|
||||
KeyIDKey: "foo.internal",
|
||||
PrincipalsKey: []string{"foo.internal", "bar.internal"},
|
||||
ExtensionsKey: map[string]interface{}{"foo": "bar"},
|
||||
}, key}, Options{
|
||||
CertBuffer: bytes.NewBufferString(`{
|
||||
"type": "host",
|
||||
"keyId": "foo.internal",
|
||||
"principals": ["foo.internal","bar.internal"],
|
||||
"extensions": {"foo":"bar"}
|
||||
}`)}, false},
|
||||
{"badBase64", args{"foobar", TemplateData{}, key}, Options{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got Options
|
||||
fn := WithTemplateBase64(tt.args.s, tt.args.data)
|
||||
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
|
||||
t.Errorf("WithTemplateBase64() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("WithTemplateBase64() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithTemplateFile(t *testing.T) {
|
||||
key := mustGeneratePublicKey(t)
|
||||
|
||||
data := TemplateData{
|
||||
TypeKey: "user",
|
||||
KeyIDKey: "jane@doe.com",
|
||||
PrincipalsKey: []string{"jane", "jane@doe.com"},
|
||||
ExtensionsKey: DefaultExtensions(UserCert),
|
||||
InsecureKey: map[string]interface{}{
|
||||
UserKey: map[string]interface{}{
|
||||
"username": "jane",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
type args struct {
|
||||
path string
|
||||
data TemplateData
|
||||
key ssh.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want Options
|
||||
wantErr bool
|
||||
}{
|
||||
{"github.com", args{"./testdata/github.tpl", data, key}, Options{
|
||||
CertBuffer: bytes.NewBufferString(`{
|
||||
"type": "user",
|
||||
"keyId": "jane@doe.com",
|
||||
"principals": ["jane","jane@doe.com"],
|
||||
"extensions": {"login@github.com":"jane","permit-X11-forwarding":"","permit-agent-forwarding":"","permit-port-forwarding":"","permit-pty":"","permit-user-rc":""}
|
||||
}`),
|
||||
}, false},
|
||||
{"missing", args{"./testdata/missing.tpl", data, key}, Options{}, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got Options
|
||||
fn := WithTemplateFile(tt.args.path, tt.args.data)
|
||||
if err := fn(tt.args.key, &got); (err != nil) != tt.wantErr {
|
||||
t.Errorf("WithTemplateFile() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("WithTemplateFile() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
132
sshutil/templates.go
Normal file
132
sshutil/templates.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
package sshutil
|
||||
|
||||
import "golang.org/x/crypto/ssh"
|
||||
|
||||
const (
|
||||
TypeKey = "Type"
|
||||
KeyIDKey = "KeyID"
|
||||
PrincipalsKey = "Principals"
|
||||
ExtensionsKey = "Extensions"
|
||||
TokenKey = "Token"
|
||||
InsecureKey = "Insecure"
|
||||
UserKey = "User"
|
||||
PublicKey = "PublicKey"
|
||||
)
|
||||
|
||||
// TemplateError represents an error in a template produced by the fail
|
||||
// function.
|
||||
type TemplateError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface and returns the error string when a
|
||||
// template executes the `fail "message"` function.
|
||||
func (e *TemplateError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// TemplateData is an alias for map[string]interface{}. It represents the data
|
||||
// passed to the templates.
|
||||
type TemplateData map[string]interface{}
|
||||
|
||||
// CreateTemplateData returns a TemplateData with the given certificate type,
|
||||
// key id, principals, and the default extensions.
|
||||
func CreateTemplateData(ct CertType, keyID string, principals []string) TemplateData {
|
||||
return TemplateData{
|
||||
TypeKey: ct.String(),
|
||||
KeyIDKey: keyID,
|
||||
PrincipalsKey: principals,
|
||||
ExtensionsKey: DefaultExtensions(ct),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultExtensions returns the default extensions set in an SSH certificate.
|
||||
func DefaultExtensions(ct CertType) map[string]interface{} {
|
||||
switch ct {
|
||||
case UserCert:
|
||||
return map[string]interface{}{
|
||||
"permit-X11-forwarding": "",
|
||||
"permit-agent-forwarding": "",
|
||||
"permit-port-forwarding": "",
|
||||
"permit-pty": "",
|
||||
"permit-user-rc": "",
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewTemplateData creates a new map for templates data.
|
||||
func NewTemplateData() TemplateData {
|
||||
return TemplateData{}
|
||||
}
|
||||
|
||||
// AddExtension adds one extension to the templates data.
|
||||
func (t TemplateData) AddExtension(key, value string) {
|
||||
if m, ok := t[ExtensionsKey].(map[string]interface{}); ok {
|
||||
m[key] = value
|
||||
} else {
|
||||
t[ExtensionsKey] = map[string]interface{}{
|
||||
key: value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a key-value pair in the template data.
|
||||
func (t TemplateData) Set(key string, v interface{}) {
|
||||
t[key] = v
|
||||
}
|
||||
|
||||
// SetInsecure sets a key-value pair in the insecure template data.
|
||||
func (t TemplateData) SetInsecure(key string, v interface{}) {
|
||||
if m, ok := t[InsecureKey].(TemplateData); ok {
|
||||
m[key] = v
|
||||
} else {
|
||||
t[InsecureKey] = TemplateData{key: v}
|
||||
}
|
||||
}
|
||||
|
||||
// SetType sets the certificate type in the template data.
|
||||
func (t TemplateData) SetType(typ CertType) {
|
||||
t.Set(TypeKey, typ.String())
|
||||
}
|
||||
|
||||
// SetType sets the certificate key id in the template data.
|
||||
func (t TemplateData) SetKeyID(id string) {
|
||||
t.Set(KeyIDKey, id)
|
||||
}
|
||||
|
||||
// SetPrincipals sets the certificate principals in the template data.
|
||||
func (t TemplateData) SetPrincipals(p []string) {
|
||||
t.Set(PrincipalsKey, p)
|
||||
}
|
||||
|
||||
// SetExtensions sets the certificate extensions in the template data.
|
||||
func (t TemplateData) SetExtensions(e map[string]interface{}) {
|
||||
t.Set(ExtensionsKey, e)
|
||||
}
|
||||
|
||||
// SetToken sets the given token in the template data.
|
||||
func (t TemplateData) SetToken(v interface{}) {
|
||||
t.Set(TokenKey, v)
|
||||
}
|
||||
|
||||
// SetUserData sets the given user provided object in the insecure template
|
||||
// data.
|
||||
func (t TemplateData) SetUserData(v interface{}) {
|
||||
t.SetInsecure(UserKey, v)
|
||||
}
|
||||
|
||||
// SetUserData sets the given user provided object in the insecure template
|
||||
// data.
|
||||
func (t TemplateData) SetPublicKey(v ssh.PublicKey) {
|
||||
t.Set(PublicKey, v)
|
||||
}
|
||||
|
||||
// DefaultCertificate is the default template for an SSH certificate.
|
||||
const DefaultCertificate = `{
|
||||
"type": "{{ .Type }}",
|
||||
"keyId": "{{ .KeyID }}",
|
||||
"principals": {{ toJson .Principals }},
|
||||
"extensions": {{ toJson .Extensions }}
|
||||
}`
|
406
sshutil/templates_test.go
Normal file
406
sshutil/templates_test.go
Normal file
|
@ -0,0 +1,406 @@
|
|||
package sshutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestTemplateError_Error(t *testing.T) {
|
||||
type fields struct {
|
||||
Message string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want string
|
||||
}{
|
||||
{"ok", fields{"message"}, "message"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &TemplateError{
|
||||
Message: tt.fields.Message,
|
||||
}
|
||||
if got := e.Error(); got != tt.want {
|
||||
t.Errorf("TemplateError.Error() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTemplateData(t *testing.T) {
|
||||
type args struct {
|
||||
ct CertType
|
||||
keyID string
|
||||
principals []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"user", args{UserCert, "john@doe.com", []string{"john", "john@doe.com"}}, TemplateData{
|
||||
TypeKey: "user",
|
||||
KeyIDKey: "john@doe.com",
|
||||
PrincipalsKey: []string{"john", "john@doe.com"},
|
||||
ExtensionsKey: map[string]interface{}{
|
||||
"permit-X11-forwarding": "",
|
||||
"permit-agent-forwarding": "",
|
||||
"permit-port-forwarding": "",
|
||||
"permit-pty": "",
|
||||
"permit-user-rc": "",
|
||||
},
|
||||
}},
|
||||
{"host", args{HostCert, "foo", []string{"foo.internal"}}, TemplateData{
|
||||
TypeKey: "host",
|
||||
KeyIDKey: "foo",
|
||||
PrincipalsKey: []string{"foo.internal"},
|
||||
ExtensionsKey: map[string]interface{}(nil),
|
||||
}},
|
||||
{"other", args{100, "foo", []string{"foo.internal"}}, TemplateData{
|
||||
TypeKey: "",
|
||||
KeyIDKey: "foo",
|
||||
PrincipalsKey: []string{"foo.internal"},
|
||||
ExtensionsKey: map[string]interface{}(nil),
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := CreateTemplateData(tt.args.ct, tt.args.keyID, tt.args.principals); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CreateTemplateData() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultExtensions(t *testing.T) {
|
||||
type args struct {
|
||||
ct CertType
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want map[string]interface{}
|
||||
}{
|
||||
{"user", args{UserCert}, map[string]interface{}{
|
||||
"permit-X11-forwarding": "",
|
||||
"permit-agent-forwarding": "",
|
||||
"permit-port-forwarding": "",
|
||||
"permit-pty": "",
|
||||
"permit-user-rc": "",
|
||||
}},
|
||||
{"host", args{HostCert}, nil},
|
||||
{"other", args{100}, nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := DefaultExtensions(tt.args.ct); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("DefaultExtensions() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTemplateData(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want TemplateData
|
||||
}{
|
||||
{"ok", TemplateData{}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := NewTemplateData(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewTemplateData() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_AddExtension(t *testing.T) {
|
||||
type args struct {
|
||||
key string
|
||||
value string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
t TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"empty", TemplateData{}, args{"key", "value"}, TemplateData{
|
||||
ExtensionsKey: map[string]interface{}{"key": "value"},
|
||||
}},
|
||||
{"overwrite", TemplateData{
|
||||
ExtensionsKey: map[string]interface{}{"key": "value"},
|
||||
}, args{"key", "value"}, TemplateData{
|
||||
ExtensionsKey: map[string]interface{}{
|
||||
"key": "value",
|
||||
},
|
||||
}},
|
||||
{"add", TemplateData{
|
||||
ExtensionsKey: map[string]interface{}{"foo": "bar"},
|
||||
}, args{"key", "value"}, TemplateData{
|
||||
ExtensionsKey: map[string]interface{}{
|
||||
"key": "value",
|
||||
"foo": "bar",
|
||||
},
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.t.AddExtension(tt.args.key, tt.args.value)
|
||||
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||
t.Errorf("AddExtension() = %v, want %v", tt.t, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_Set(t *testing.T) {
|
||||
type args struct {
|
||||
key string
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
t TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"ok", TemplateData{}, args{"foo", "bar"}, TemplateData{
|
||||
"foo": "bar",
|
||||
}},
|
||||
{"overwrite", TemplateData{}, args{"foo", "bar"}, TemplateData{
|
||||
"foo": "bar",
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.t.Set(tt.args.key, tt.args.v)
|
||||
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||
t.Errorf("Set() = %v, want %v", tt.t, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_SetInsecure(t *testing.T) {
|
||||
type args struct {
|
||||
key string
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
td TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"empty", TemplateData{}, args{"foo", "bar"}, TemplateData{InsecureKey: TemplateData{"foo": "bar"}}},
|
||||
{"overwrite", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"foo", "zar"}, TemplateData{InsecureKey: TemplateData{"foo": "zar"}}},
|
||||
{"add", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"bar", "foo"}, TemplateData{InsecureKey: TemplateData{"foo": "bar", "bar": "foo"}}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.td.SetInsecure(tt.args.key, tt.args.v)
|
||||
if !reflect.DeepEqual(tt.td, tt.want) {
|
||||
t.Errorf("TemplateData.SetInsecure() = %v, want %v", tt.td, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_SetType(t *testing.T) {
|
||||
type args struct {
|
||||
typ CertType
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
t TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"user", TemplateData{}, args{UserCert}, TemplateData{
|
||||
TypeKey: "user",
|
||||
}},
|
||||
{"host", TemplateData{}, args{HostCert}, TemplateData{
|
||||
TypeKey: "host",
|
||||
}},
|
||||
{"overwrite", TemplateData{
|
||||
TypeKey: "host",
|
||||
}, args{UserCert}, TemplateData{
|
||||
TypeKey: "user",
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.t.SetType(tt.args.typ)
|
||||
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||
t.Errorf("SetType() = %v, want %v", tt.t, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_SetKeyID(t *testing.T) {
|
||||
type args struct {
|
||||
id string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
t TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"ok", TemplateData{}, args{"key-id"}, TemplateData{
|
||||
KeyIDKey: "key-id",
|
||||
}},
|
||||
{"overwrite", TemplateData{}, args{"key-id-2"}, TemplateData{
|
||||
KeyIDKey: "key-id-2",
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.t.SetKeyID(tt.args.id)
|
||||
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||
t.Errorf("SetKeyID() = %v, want %v", tt.t, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_SetPrincipals(t *testing.T) {
|
||||
type args struct {
|
||||
p []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
t TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"ok", TemplateData{}, args{[]string{"jane"}}, TemplateData{
|
||||
PrincipalsKey: []string{"jane"},
|
||||
}},
|
||||
{"overwrite", TemplateData{}, args{[]string{"john", "john@doe.com"}}, TemplateData{
|
||||
PrincipalsKey: []string{"john", "john@doe.com"},
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.t.SetPrincipals(tt.args.p)
|
||||
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||
t.Errorf("SetPrincipals() = %v, want %v", tt.t, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_SetExtensions(t *testing.T) {
|
||||
type args struct {
|
||||
e map[string]interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
t TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"ok", TemplateData{}, args{map[string]interface{}{"foo": "bar"}}, TemplateData{
|
||||
ExtensionsKey: map[string]interface{}{"foo": "bar"},
|
||||
}},
|
||||
{"overwrite", TemplateData{
|
||||
ExtensionsKey: map[string]interface{}{"foo": "bar"},
|
||||
}, args{map[string]interface{}{"key": "value"}}, TemplateData{
|
||||
ExtensionsKey: map[string]interface{}{"key": "value"},
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.t.SetExtensions(tt.args.e)
|
||||
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||
t.Errorf("SetExtensions() = %v, want %v", tt.t, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_SetToken(t *testing.T) {
|
||||
type args struct {
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
td TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"ok", TemplateData{}, args{"token"}, TemplateData{TokenKey: "token"}},
|
||||
{"overwrite", TemplateData{TokenKey: "foo"}, args{"token"}, TemplateData{TokenKey: "token"}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.td.SetToken(tt.args.v)
|
||||
if !reflect.DeepEqual(tt.td, tt.want) {
|
||||
t.Errorf("TemplateData.SetToken() = %v, want %v", tt.td, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_SetUserData(t *testing.T) {
|
||||
type args struct {
|
||||
v interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
td TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"ok", TemplateData{}, args{"userData"}, TemplateData{InsecureKey: TemplateData{UserKey: "userData"}}},
|
||||
{"overwrite", TemplateData{InsecureKey: TemplateData{UserKey: "foo"}}, args{"userData"}, TemplateData{InsecureKey: TemplateData{UserKey: "userData"}}},
|
||||
{"existing", TemplateData{InsecureKey: TemplateData{"foo": "bar"}}, args{"userData"}, TemplateData{InsecureKey: TemplateData{"foo": "bar", UserKey: "userData"}}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.td.SetUserData(tt.args.v)
|
||||
if !reflect.DeepEqual(tt.td, tt.want) {
|
||||
t.Errorf("TemplateData.SetUserData() = %v, want %v", tt.td, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateData_SetPublicKey(t *testing.T) {
|
||||
k1 := mustGeneratePublicKey(t)
|
||||
k2 := mustGeneratePublicKey(t)
|
||||
type args struct {
|
||||
v ssh.PublicKey
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
t TemplateData
|
||||
args args
|
||||
want TemplateData
|
||||
}{
|
||||
{"ok", TemplateData{}, args{k1}, TemplateData{
|
||||
PublicKey: k1,
|
||||
}},
|
||||
{"overwrite", TemplateData{
|
||||
PublicKey: k1,
|
||||
}, args{k2}, TemplateData{
|
||||
PublicKey: k2,
|
||||
}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.t.SetPublicKey(tt.args.v)
|
||||
if !reflect.DeepEqual(tt.t, tt.want) {
|
||||
t.Errorf("TemplateData.SetPublicKey() = %v, want %v", tt.t, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
10
sshutil/testdata/github.tpl
vendored
Normal file
10
sshutil/testdata/github.tpl
vendored
Normal file
|
@ -0,0 +1,10 @@
|
|||
{
|
||||
"type": "{{ .Type }}",
|
||||
"keyId": "{{ .KeyID }}",
|
||||
"principals": {{ toJson .Principals }},
|
||||
{{- if .Insecure.User.username }}
|
||||
"extensions": {{ set .Extensions "login@github.com" .Insecure.User.username | toJson }}
|
||||
{{- else }}
|
||||
"extensions": {{ toJson .Extensions }}
|
||||
{{- end }}
|
||||
}
|
|
@ -1,5 +1,13 @@
|
|||
package sshutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// Hosts are tagged with k,v pairs. These tags are how a user is ultimately
|
||||
// associated with a host.
|
||||
type HostTag struct {
|
||||
|
@ -14,3 +22,60 @@ type Host struct {
|
|||
HostTags []HostTag `json:"host_tags"`
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
// CertType defines the certificate type, it can be a user or a host
|
||||
// certificate.
|
||||
type CertType uint32
|
||||
|
||||
const (
|
||||
// UserCert defines a user certificate.
|
||||
UserCert CertType = ssh.UserCert
|
||||
|
||||
// HostCert defines a host certificate.
|
||||
HostCert CertType = ssh.HostCert
|
||||
)
|
||||
|
||||
const (
|
||||
userString = "user"
|
||||
hostString = "host"
|
||||
)
|
||||
|
||||
// String returns "user" for user certificates and "host" for host certificates.
|
||||
// It will return the empty string for any other value.
|
||||
func (c CertType) String() string {
|
||||
switch c {
|
||||
case UserCert:
|
||||
return userString
|
||||
case HostCert:
|
||||
return hostString
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface for CertType. UserCert
|
||||
// will be marshaled as the string "user" and HostCert as "host".
|
||||
func (c CertType) MarshalJSON() ([]byte, error) {
|
||||
if s := c.String(); s != "" {
|
||||
return []byte(`"` + s + `"`), nil
|
||||
}
|
||||
return nil, errors.Errorf("unknown certificate type %d", c)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json.Unmarshaler interface for CertType.
|
||||
func (c *CertType) UnmarshalJSON(data []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrap(err, "error unmarshaling certificate type")
|
||||
}
|
||||
switch strings.ToLower(s) {
|
||||
case userString:
|
||||
*c = UserCert
|
||||
return nil
|
||||
case hostString:
|
||||
*c = HostCert
|
||||
return nil
|
||||
default:
|
||||
return errors.Errorf("error unmarshaling '%s' as a certificate type", s)
|
||||
}
|
||||
}
|
||||
|
|
82
sshutil/types_test.go
Normal file
82
sshutil/types_test.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
package sshutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCertType_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
c CertType
|
||||
want string
|
||||
}{
|
||||
{"user", UserCert, "user"},
|
||||
{"host", HostCert, "host"},
|
||||
{"empty", 100, ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.c.String(); got != tt.want {
|
||||
t.Errorf("CertType.String() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertType_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
c CertType
|
||||
want []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{"user", UserCert, []byte(`"user"`), false},
|
||||
{"host", HostCert, []byte(`"host"`), false},
|
||||
{"error", 100, nil, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.c.MarshalJSON()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CertType.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CertType.MarshalJSON() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertType_UnmarshalJSON(t *testing.T) {
|
||||
type args struct {
|
||||
data []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want CertType
|
||||
wantErr bool
|
||||
}{
|
||||
{"user", args{[]byte(`"user"`)}, UserCert, false},
|
||||
{"USER", args{[]byte(`"USER"`)}, UserCert, false},
|
||||
{"host", args{[]byte(`"host"`)}, HostCert, false},
|
||||
{"HosT", args{[]byte(`"HosT"`)}, HostCert, false},
|
||||
{" user ", args{[]byte(`" user "`)}, 0, true},
|
||||
{"number", args{[]byte(`1`)}, 0, true},
|
||||
{"object", args{[]byte(`{}`)}, 0, true},
|
||||
{"badJSON", args{[]byte(`"user`)}, 0, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ct CertType
|
||||
if err := ct.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
|
||||
t.Errorf("CertType.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !reflect.DeepEqual(ct, tt.want) {
|
||||
t.Errorf("CertType.UnmarshalJSON() = %v, want %v", ct, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue