Create method CertTypeFromString(s string).

This commit is contained in:
Mariano Cano 2020-07-27 11:06:51 -07:00
parent c6746425a3
commit fdd0eb6773
2 changed files with 47 additions and 8 deletions

View file

@ -40,6 +40,18 @@ const (
hostString = "host" hostString = "host"
) )
// CertTypeFromString returns the CertType for the string "user" and "host".
func CertTypeFromString(s string) (CertType, error) {
switch strings.ToLower(s) {
case userString:
return UserCert, nil
case hostString:
return HostCert, nil
default:
return 0, errors.Errorf("unknown certificate type '%s'", s)
}
}
// String returns "user" for user certificates and "host" for host certificates. // String returns "user" for user certificates and "host" for host certificates.
// It will return the empty string for any other value. // It will return the empty string for any other value.
func (c CertType) String() string { func (c CertType) String() string {
@ -68,14 +80,10 @@ func (c *CertType) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &s); err != nil { if err := json.Unmarshal(data, &s); err != nil {
return errors.Wrap(err, "error unmarshaling certificate type") return errors.Wrap(err, "error unmarshaling certificate type")
} }
switch strings.ToLower(s) { certType, err := CertTypeFromString(s)
case userString: if err != nil {
*c = UserCert
return nil
case hostString:
*c = HostCert
return nil
default:
return errors.Errorf("error unmarshaling '%s' as a certificate type", s) return errors.Errorf("error unmarshaling '%s' as a certificate type", s)
} }
*c = certType
return nil
} }

View file

@ -5,6 +5,37 @@ import (
"testing" "testing"
) )
func TestCertTypeFromString(t *testing.T) {
type args struct {
s string
}
tests := []struct {
name string
args args
want CertType
wantErr bool
}{
{"user", args{"user"}, UserCert, false},
{"USER", args{"USER"}, UserCert, false},
{"host", args{"host"}, HostCert, false},
{"Host", args{"Host"}, HostCert, false},
{" user ", args{" user "}, 0, true},
{"invalid", args{"invalid"}, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CertTypeFromString(tt.args.s)
if (err != nil) != tt.wantErr {
t.Errorf("CertTypeFromString() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("CertTypeFromString() = %v, want %v", got, tt.want)
}
})
}
}
func TestCertType_String(t *testing.T) { func TestCertType_String(t *testing.T) {
tests := []struct { tests := []struct {
name string name string