forked from TrueCloudLab/certificates
348 lines
10 KiB
Go
348 lines
10 KiB
Go
|
package sshutil
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"crypto/ed25519"
|
||
|
"crypto/rand"
|
||
|
"encoding/base64"
|
||
|
"reflect"
|
||
|
"testing"
|
||
|
|
||
|
"golang.org/x/crypto/ssh"
|
||
|
)
|
||
|
|
||
|
func mustGenerateKey(t *testing.T) (ssh.PublicKey, ssh.Signer) {
|
||
|
t.Helper()
|
||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
key, err := ssh.NewPublicKey(pub)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
signer, err := ssh.NewSignerFromKey(priv)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
return key, signer
|
||
|
}
|
||
|
|
||
|
func mustGeneratePublicKey(t *testing.T) ssh.PublicKey {
|
||
|
t.Helper()
|
||
|
key, _ := mustGenerateKey(t)
|
||
|
return key
|
||
|
}
|
||
|
|
||
|
func TestNewCertificate(t *testing.T) {
|
||
|
key := mustGeneratePublicKey(t)
|
||
|
|
||
|
type args struct {
|
||
|
key ssh.PublicKey
|
||
|
opts []Option
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
args args
|
||
|
want *Certificate
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"user", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(UserCert, "jane@doe.com", []string{"jane"}))}}, &Certificate{
|
||
|
Nonce: nil,
|
||
|
Key: key,
|
||
|
Serial: 0,
|
||
|
Type: UserCert,
|
||
|
KeyID: "jane@doe.com",
|
||
|
Principals: []string{"jane"},
|
||
|
ValidAfter: 0,
|
||
|
ValidBefore: 0,
|
||
|
CriticalOptions: nil,
|
||
|
Extensions: map[string]string{
|
||
|
"permit-X11-forwarding": "",
|
||
|
"permit-agent-forwarding": "",
|
||
|
"permit-port-forwarding": "",
|
||
|
"permit-pty": "",
|
||
|
"permit-user-rc": "",
|
||
|
},
|
||
|
Reserved: nil,
|
||
|
SignatureKey: nil,
|
||
|
Signature: nil,
|
||
|
}, false},
|
||
|
{"host", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(HostCert, "foobar", []string{"foo.internal", "bar.internal"}))}}, &Certificate{
|
||
|
Nonce: nil,
|
||
|
Key: key,
|
||
|
Serial: 0,
|
||
|
Type: HostCert,
|
||
|
KeyID: "foobar",
|
||
|
Principals: []string{"foo.internal", "bar.internal"},
|
||
|
ValidAfter: 0,
|
||
|
ValidBefore: 0,
|
||
|
CriticalOptions: nil,
|
||
|
Extensions: nil,
|
||
|
Reserved: nil,
|
||
|
SignatureKey: nil,
|
||
|
Signature: nil,
|
||
|
}, false},
|
||
|
{"file", args{key, []Option{WithTemplateFile("./testdata/github.tpl", TemplateData{
|
||
|
TypeKey: UserCert,
|
||
|
KeyIDKey: "john@doe.com",
|
||
|
PrincipalsKey: []string{"john", "john@doe.com"},
|
||
|
ExtensionsKey: DefaultExtensions(UserCert),
|
||
|
InsecureKey: map[string]interface{}{
|
||
|
"User": map[string]interface{}{"username": "john"},
|
||
|
},
|
||
|
})}}, &Certificate{
|
||
|
Nonce: nil,
|
||
|
Key: key,
|
||
|
Serial: 0,
|
||
|
Type: UserCert,
|
||
|
KeyID: "john@doe.com",
|
||
|
Principals: []string{"john", "john@doe.com"},
|
||
|
ValidAfter: 0,
|
||
|
ValidBefore: 0,
|
||
|
CriticalOptions: nil,
|
||
|
Extensions: map[string]string{
|
||
|
"permit-X11-forwarding": "",
|
||
|
"permit-agent-forwarding": "",
|
||
|
"permit-port-forwarding": "",
|
||
|
"permit-pty": "",
|
||
|
"permit-user-rc": "",
|
||
|
"login@github.com": "john",
|
||
|
},
|
||
|
Reserved: nil,
|
||
|
SignatureKey: nil,
|
||
|
Signature: nil,
|
||
|
}, false},
|
||
|
{"base64", args{key, []Option{WithTemplateBase64(base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), CreateTemplateData(HostCert, "foo.internal", nil))}}, &Certificate{
|
||
|
Nonce: nil,
|
||
|
Key: key,
|
||
|
Serial: 0,
|
||
|
Type: HostCert,
|
||
|
KeyID: "foo.internal",
|
||
|
Principals: nil,
|
||
|
ValidAfter: 0,
|
||
|
ValidBefore: 0,
|
||
|
CriticalOptions: nil,
|
||
|
Extensions: nil,
|
||
|
Reserved: nil,
|
||
|
SignatureKey: nil,
|
||
|
Signature: nil,
|
||
|
}, false},
|
||
|
{"failNilOptions", args{key, nil}, nil, true},
|
||
|
{"failEmptyOptions", args{key, nil}, nil, true},
|
||
|
{"badBase64Template", args{key, []Option{WithTemplateBase64("foobar", TemplateData{})}}, nil, true},
|
||
|
{"badFileTemplate", args{key, []Option{WithTemplateFile("./testdata/missing.tpl", TemplateData{})}}, nil, true},
|
||
|
{"badJsonTemplate", args{key, []Option{WithTemplate(`{"type":{{ .Type }}}`, TemplateData{})}}, nil, true},
|
||
|
{"failTemplate", args{key, []Option{WithTemplate(`{{ fail "an error" }}`, TemplateData{})}}, nil, true},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
got, err := NewCertificate(tt.args.key, tt.args.opts...)
|
||
|
if (err != nil) != tt.wantErr {
|
||
|
t.Errorf("NewCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
return
|
||
|
}
|
||
|
if !reflect.DeepEqual(got, tt.want) {
|
||
|
t.Errorf("NewCertificate() = %v, want %v", got, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCertificate_GetCertificate(t *testing.T) {
|
||
|
key := mustGeneratePublicKey(t)
|
||
|
|
||
|
type fields struct {
|
||
|
Nonce []byte
|
||
|
Key ssh.PublicKey
|
||
|
Serial uint64
|
||
|
Type CertType
|
||
|
KeyID string
|
||
|
Principals []string
|
||
|
ValidAfter uint64
|
||
|
ValidBefore uint64
|
||
|
CriticalOptions map[string]string
|
||
|
Extensions map[string]string
|
||
|
Reserved []byte
|
||
|
SignatureKey ssh.PublicKey
|
||
|
Signature *ssh.Signature
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
fields fields
|
||
|
want *ssh.Certificate
|
||
|
}{
|
||
|
{"user", fields{
|
||
|
Nonce: []byte("0123456789"),
|
||
|
Key: key,
|
||
|
Serial: 123,
|
||
|
Type: UserCert,
|
||
|
KeyID: "key-id",
|
||
|
Principals: []string{"john"},
|
||
|
ValidAfter: 1111,
|
||
|
ValidBefore: 2222,
|
||
|
CriticalOptions: map[string]string{"foo": "bar"},
|
||
|
Extensions: map[string]string{"login@github.com": "john"},
|
||
|
Reserved: []byte("reserved"),
|
||
|
SignatureKey: key,
|
||
|
Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")},
|
||
|
}, &ssh.Certificate{
|
||
|
Nonce: []byte("0123456789"),
|
||
|
Key: key,
|
||
|
Serial: 123,
|
||
|
CertType: ssh.UserCert,
|
||
|
KeyId: "key-id",
|
||
|
ValidPrincipals: []string{"john"},
|
||
|
ValidAfter: 1111,
|
||
|
ValidBefore: 2222,
|
||
|
Permissions: ssh.Permissions{
|
||
|
CriticalOptions: map[string]string{"foo": "bar"},
|
||
|
Extensions: map[string]string{"login@github.com": "john"},
|
||
|
},
|
||
|
Reserved: []byte("reserved"),
|
||
|
}},
|
||
|
{"host", fields{
|
||
|
Nonce: []byte("0123456789"),
|
||
|
Key: key,
|
||
|
Serial: 123,
|
||
|
Type: HostCert,
|
||
|
KeyID: "key-id",
|
||
|
Principals: []string{"foo.internal", "bar.internal"},
|
||
|
ValidAfter: 1111,
|
||
|
ValidBefore: 2222,
|
||
|
CriticalOptions: map[string]string{"foo": "bar"},
|
||
|
Extensions: nil,
|
||
|
Reserved: []byte("reserved"),
|
||
|
SignatureKey: key,
|
||
|
Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")},
|
||
|
}, &ssh.Certificate{
|
||
|
Nonce: []byte("0123456789"),
|
||
|
Key: key,
|
||
|
Serial: 123,
|
||
|
CertType: ssh.HostCert,
|
||
|
KeyId: "key-id",
|
||
|
ValidPrincipals: []string{"foo.internal", "bar.internal"},
|
||
|
ValidAfter: 1111,
|
||
|
ValidBefore: 2222,
|
||
|
Permissions: ssh.Permissions{
|
||
|
CriticalOptions: map[string]string{"foo": "bar"},
|
||
|
Extensions: nil,
|
||
|
},
|
||
|
Reserved: []byte("reserved"),
|
||
|
}},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
c := &Certificate{
|
||
|
Nonce: tt.fields.Nonce,
|
||
|
Key: tt.fields.Key,
|
||
|
Serial: tt.fields.Serial,
|
||
|
Type: tt.fields.Type,
|
||
|
KeyID: tt.fields.KeyID,
|
||
|
Principals: tt.fields.Principals,
|
||
|
ValidAfter: tt.fields.ValidAfter,
|
||
|
ValidBefore: tt.fields.ValidBefore,
|
||
|
CriticalOptions: tt.fields.CriticalOptions,
|
||
|
Extensions: tt.fields.Extensions,
|
||
|
Reserved: tt.fields.Reserved,
|
||
|
SignatureKey: tt.fields.SignatureKey,
|
||
|
Signature: tt.fields.Signature,
|
||
|
}
|
||
|
if got := c.GetCertificate(); !reflect.DeepEqual(got, tt.want) {
|
||
|
t.Errorf("Certificate.GetCertificate() = %v, want %v", got, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestCreateCertificate(t *testing.T) {
|
||
|
key, signer := mustGenerateKey(t)
|
||
|
type args struct {
|
||
|
cert *ssh.Certificate
|
||
|
signer ssh.Signer
|
||
|
}
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
args args
|
||
|
wantErr bool
|
||
|
}{
|
||
|
{"ok", args{&ssh.Certificate{
|
||
|
Nonce: []byte("0123456789"),
|
||
|
Key: key,
|
||
|
Serial: 123,
|
||
|
CertType: ssh.HostCert,
|
||
|
KeyId: "foo",
|
||
|
ValidPrincipals: []string{"foo.internal"},
|
||
|
ValidAfter: 1111,
|
||
|
ValidBefore: 2222,
|
||
|
Permissions: ssh.Permissions{},
|
||
|
Reserved: []byte("reserved"),
|
||
|
}, signer}, false},
|
||
|
{"emptyNonce", args{&ssh.Certificate{
|
||
|
Key: key,
|
||
|
Serial: 123,
|
||
|
CertType: ssh.UserCert,
|
||
|
KeyId: "jane@doe.com",
|
||
|
ValidPrincipals: []string{"jane"},
|
||
|
ValidAfter: 1111,
|
||
|
ValidBefore: 2222,
|
||
|
Permissions: ssh.Permissions{},
|
||
|
Reserved: []byte("reserved"),
|
||
|
}, signer}, false},
|
||
|
{"emptySerial", args{&ssh.Certificate{
|
||
|
Nonce: []byte("0123456789"),
|
||
|
Key: key,
|
||
|
CertType: ssh.UserCert,
|
||
|
KeyId: "jane@doe.com",
|
||
|
ValidPrincipals: []string{"jane"},
|
||
|
ValidAfter: 1111,
|
||
|
ValidBefore: 2222,
|
||
|
Permissions: ssh.Permissions{},
|
||
|
Reserved: []byte("reserved"),
|
||
|
}, signer}, false},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
got, err := CreateCertificate(tt.args.cert, tt.args.signer)
|
||
|
if (err != nil) != tt.wantErr {
|
||
|
t.Errorf("CreateCertificate() error = %v, wantErr %v", err, tt.wantErr)
|
||
|
return
|
||
|
}
|
||
|
if got != nil {
|
||
|
switch {
|
||
|
case len(got.Nonce) == 0:
|
||
|
t.Errorf("CreateCertificate() nonce should not be empty")
|
||
|
case got.Serial == 0:
|
||
|
t.Errorf("CreateCertificate() serial should not be 0")
|
||
|
case got.Signature == nil:
|
||
|
t.Errorf("CreateCertificate() signature should not be nil")
|
||
|
case !bytes.Equal(got.SignatureKey.Marshal(), tt.args.signer.PublicKey().Marshal()):
|
||
|
t.Errorf("CreateCertificate() signature key is not the expected one")
|
||
|
}
|
||
|
|
||
|
signature := got.Signature
|
||
|
got.Signature = nil
|
||
|
|
||
|
data := got.Marshal()
|
||
|
data = data[:len(data)-4]
|
||
|
|
||
|
sig, err := signer.Sign(rand.Reader, data)
|
||
|
if err != nil {
|
||
|
t.Errorf("signer.Sign() error = %v", err)
|
||
|
}
|
||
|
|
||
|
// Verify signature
|
||
|
got.Signature = signature
|
||
|
if err := signer.PublicKey().Verify(data, got.Signature); err != nil {
|
||
|
t.Errorf("CreateCertificate() signature verify error = %v", err)
|
||
|
}
|
||
|
// Verify data with public key in cert
|
||
|
if err := got.Verify(data, sig); err != nil {
|
||
|
t.Errorf("CreateCertificate() certificate verify error = %v", err)
|
||
|
}
|
||
|
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|