diff --git a/sshutil/types.go b/sshutil/types.go index bf62933b..d1c45564 100644 --- a/sshutil/types.go +++ b/sshutil/types.go @@ -40,6 +40,18 @@ const ( hostString = "host" ) +// CertTypeFromString returns the CertType for the string "user" and "host". +func CertTypeFromString(s string) (CertType, error) { + switch strings.ToLower(s) { + case userString: + return UserCert, nil + case hostString: + return HostCert, nil + default: + return 0, errors.Errorf("unknown certificate type '%s'", s) + } +} + // String returns "user" for user certificates and "host" for host certificates. // It will return the empty string for any other value. func (c CertType) String() string { @@ -68,14 +80,10 @@ func (c *CertType) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &s); err != nil { return errors.Wrap(err, "error unmarshaling certificate type") } - switch strings.ToLower(s) { - case userString: - *c = UserCert - return nil - case hostString: - *c = HostCert - return nil - default: + certType, err := CertTypeFromString(s) + if err != nil { return errors.Errorf("error unmarshaling '%s' as a certificate type", s) } + *c = certType + return nil } diff --git a/sshutil/types_test.go b/sshutil/types_test.go index 15306554..556b461c 100644 --- a/sshutil/types_test.go +++ b/sshutil/types_test.go @@ -5,6 +5,37 @@ import ( "testing" ) +func TestCertTypeFromString(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want CertType + wantErr bool + }{ + {"user", args{"user"}, UserCert, false}, + {"USER", args{"USER"}, UserCert, false}, + {"host", args{"host"}, HostCert, false}, + {"Host", args{"Host"}, HostCert, false}, + {" user ", args{" user "}, 0, true}, + {"invalid", args{"invalid"}, 0, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CertTypeFromString(tt.args.s) + if (err != nil) != tt.wantErr { + t.Errorf("CertTypeFromString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("CertTypeFromString() = %v, want %v", got, tt.want) + } + }) + } +} + func TestCertType_String(t *testing.T) { tests := []struct { name string