package sshutil import ( "bytes" "crypto/ed25519" "crypto/rand" "encoding/base64" "reflect" "testing" "golang.org/x/crypto/ssh" ) func mustGenerateKey(t *testing.T) (ssh.PublicKey, ssh.Signer) { t.Helper() pub, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } key, err := ssh.NewPublicKey(pub) if err != nil { t.Fatal(err) } signer, err := ssh.NewSignerFromKey(priv) if err != nil { t.Fatal(err) } return key, signer } func mustGeneratePublicKey(t *testing.T) ssh.PublicKey { t.Helper() key, _ := mustGenerateKey(t) return key } func TestNewCertificate(t *testing.T) { key := mustGeneratePublicKey(t) type args struct { key ssh.PublicKey opts []Option } tests := []struct { name string args args want *Certificate wantErr bool }{ {"user", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(UserCert, "jane@doe.com", []string{"jane"}))}}, &Certificate{ Nonce: nil, Key: key, Serial: 0, Type: UserCert, KeyID: "jane@doe.com", Principals: []string{"jane"}, ValidAfter: 0, ValidBefore: 0, CriticalOptions: nil, Extensions: map[string]string{ "permit-X11-forwarding": "", "permit-agent-forwarding": "", "permit-port-forwarding": "", "permit-pty": "", "permit-user-rc": "", }, Reserved: nil, SignatureKey: nil, Signature: nil, }, false}, {"host", args{key, []Option{WithTemplate(DefaultCertificate, CreateTemplateData(HostCert, "foobar", []string{"foo.internal", "bar.internal"}))}}, &Certificate{ Nonce: nil, Key: key, Serial: 0, Type: HostCert, KeyID: "foobar", Principals: []string{"foo.internal", "bar.internal"}, ValidAfter: 0, ValidBefore: 0, CriticalOptions: nil, Extensions: nil, Reserved: nil, SignatureKey: nil, Signature: nil, }, false}, {"file", args{key, []Option{WithTemplateFile("./testdata/github.tpl", TemplateData{ TypeKey: UserCert, KeyIDKey: "john@doe.com", PrincipalsKey: []string{"john", "john@doe.com"}, ExtensionsKey: DefaultExtensions(UserCert), InsecureKey: map[string]interface{}{ "User": map[string]interface{}{"username": "john"}, }, })}}, &Certificate{ Nonce: nil, Key: key, Serial: 0, Type: UserCert, KeyID: "john@doe.com", Principals: []string{"john", "john@doe.com"}, ValidAfter: 0, ValidBefore: 0, CriticalOptions: nil, Extensions: map[string]string{ "permit-X11-forwarding": "", "permit-agent-forwarding": "", "permit-port-forwarding": "", "permit-pty": "", "permit-user-rc": "", "login@github.com": "john", }, Reserved: nil, SignatureKey: nil, Signature: nil, }, false}, {"base64", args{key, []Option{WithTemplateBase64(base64.StdEncoding.EncodeToString([]byte(DefaultCertificate)), CreateTemplateData(HostCert, "foo.internal", nil))}}, &Certificate{ Nonce: nil, Key: key, Serial: 0, Type: HostCert, KeyID: "foo.internal", Principals: nil, ValidAfter: 0, ValidBefore: 0, CriticalOptions: nil, Extensions: nil, Reserved: nil, SignatureKey: nil, Signature: nil, }, false}, {"failNilOptions", args{key, nil}, nil, true}, {"failEmptyOptions", args{key, nil}, nil, true}, {"badBase64Template", args{key, []Option{WithTemplateBase64("foobar", TemplateData{})}}, nil, true}, {"badFileTemplate", args{key, []Option{WithTemplateFile("./testdata/missing.tpl", TemplateData{})}}, nil, true}, {"badJsonTemplate", args{key, []Option{WithTemplate(`{"type":{{ .Type }}}`, TemplateData{})}}, nil, true}, {"failTemplate", args{key, []Option{WithTemplate(`{{ fail "an error" }}`, TemplateData{})}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := NewCertificate(tt.args.key, tt.args.opts...) if (err != nil) != tt.wantErr { t.Errorf("NewCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewCertificate() = %v, want %v", got, tt.want) } }) } } func TestCertificate_GetCertificate(t *testing.T) { key := mustGeneratePublicKey(t) type fields struct { Nonce []byte Key ssh.PublicKey Serial uint64 Type CertType KeyID string Principals []string ValidAfter uint64 ValidBefore uint64 CriticalOptions map[string]string Extensions map[string]string Reserved []byte SignatureKey ssh.PublicKey Signature *ssh.Signature } tests := []struct { name string fields fields want *ssh.Certificate }{ {"user", fields{ Nonce: []byte("0123456789"), Key: key, Serial: 123, Type: UserCert, KeyID: "key-id", Principals: []string{"john"}, ValidAfter: 1111, ValidBefore: 2222, CriticalOptions: map[string]string{"foo": "bar"}, Extensions: map[string]string{"login@github.com": "john"}, Reserved: []byte("reserved"), SignatureKey: key, Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")}, }, &ssh.Certificate{ Nonce: []byte("0123456789"), Key: key, Serial: 123, CertType: ssh.UserCert, KeyId: "key-id", ValidPrincipals: []string{"john"}, ValidAfter: 1111, ValidBefore: 2222, Permissions: ssh.Permissions{ CriticalOptions: map[string]string{"foo": "bar"}, Extensions: map[string]string{"login@github.com": "john"}, }, Reserved: []byte("reserved"), }}, {"host", fields{ Nonce: []byte("0123456789"), Key: key, Serial: 123, Type: HostCert, KeyID: "key-id", Principals: []string{"foo.internal", "bar.internal"}, ValidAfter: 1111, ValidBefore: 2222, CriticalOptions: map[string]string{"foo": "bar"}, Extensions: nil, Reserved: []byte("reserved"), SignatureKey: key, Signature: &ssh.Signature{Format: "foo", Blob: []byte("bar")}, }, &ssh.Certificate{ Nonce: []byte("0123456789"), Key: key, Serial: 123, CertType: ssh.HostCert, KeyId: "key-id", ValidPrincipals: []string{"foo.internal", "bar.internal"}, ValidAfter: 1111, ValidBefore: 2222, Permissions: ssh.Permissions{ CriticalOptions: map[string]string{"foo": "bar"}, Extensions: nil, }, Reserved: []byte("reserved"), }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Certificate{ Nonce: tt.fields.Nonce, Key: tt.fields.Key, Serial: tt.fields.Serial, Type: tt.fields.Type, KeyID: tt.fields.KeyID, Principals: tt.fields.Principals, ValidAfter: tt.fields.ValidAfter, ValidBefore: tt.fields.ValidBefore, CriticalOptions: tt.fields.CriticalOptions, Extensions: tt.fields.Extensions, Reserved: tt.fields.Reserved, SignatureKey: tt.fields.SignatureKey, Signature: tt.fields.Signature, } if got := c.GetCertificate(); !reflect.DeepEqual(got, tt.want) { t.Errorf("Certificate.GetCertificate() = %v, want %v", got, tt.want) } }) } } func TestCreateCertificate(t *testing.T) { key, signer := mustGenerateKey(t) type args struct { cert *ssh.Certificate signer ssh.Signer } tests := []struct { name string args args wantErr bool }{ {"ok", args{&ssh.Certificate{ Nonce: []byte("0123456789"), Key: key, Serial: 123, CertType: ssh.HostCert, KeyId: "foo", ValidPrincipals: []string{"foo.internal"}, ValidAfter: 1111, ValidBefore: 2222, Permissions: ssh.Permissions{}, Reserved: []byte("reserved"), }, signer}, false}, {"emptyNonce", args{&ssh.Certificate{ Key: key, Serial: 123, CertType: ssh.UserCert, KeyId: "jane@doe.com", ValidPrincipals: []string{"jane"}, ValidAfter: 1111, ValidBefore: 2222, Permissions: ssh.Permissions{}, Reserved: []byte("reserved"), }, signer}, false}, {"emptySerial", args{&ssh.Certificate{ Nonce: []byte("0123456789"), Key: key, CertType: ssh.UserCert, KeyId: "jane@doe.com", ValidPrincipals: []string{"jane"}, ValidAfter: 1111, ValidBefore: 2222, Permissions: ssh.Permissions{}, Reserved: []byte("reserved"), }, signer}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := CreateCertificate(tt.args.cert, tt.args.signer) if (err != nil) != tt.wantErr { t.Errorf("CreateCertificate() error = %v, wantErr %v", err, tt.wantErr) return } if got != nil { switch { case len(got.Nonce) == 0: t.Errorf("CreateCertificate() nonce should not be empty") case got.Serial == 0: t.Errorf("CreateCertificate() serial should not be 0") case got.Signature == nil: t.Errorf("CreateCertificate() signature should not be nil") case !bytes.Equal(got.SignatureKey.Marshal(), tt.args.signer.PublicKey().Marshal()): t.Errorf("CreateCertificate() signature key is not the expected one") } signature := got.Signature got.Signature = nil data := got.Marshal() data = data[:len(data)-4] sig, err := signer.Sign(rand.Reader, data) if err != nil { t.Errorf("signer.Sign() error = %v", err) } // Verify signature got.Signature = signature if err := signer.PublicKey().Verify(data, got.Signature); err != nil { t.Errorf("CreateCertificate() signature verify error = %v", err) } // Verify data with public key in cert if err := got.Verify(data, sig); err != nil { t.Errorf("CreateCertificate() certificate verify error = %v", err) } } }) } }