Create method CertTypeFromString(s string).
This commit is contained in:
parent
c6746425a3
commit
fdd0eb6773
2 changed files with 47 additions and 8 deletions
|
@ -40,6 +40,18 @@ const (
|
|||
hostString = "host"
|
||||
)
|
||||
|
||||
// CertTypeFromString returns the CertType for the string "user" and "host".
|
||||
func CertTypeFromString(s string) (CertType, error) {
|
||||
switch strings.ToLower(s) {
|
||||
case userString:
|
||||
return UserCert, nil
|
||||
case hostString:
|
||||
return HostCert, nil
|
||||
default:
|
||||
return 0, errors.Errorf("unknown certificate type '%s'", s)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns "user" for user certificates and "host" for host certificates.
|
||||
// It will return the empty string for any other value.
|
||||
func (c CertType) String() string {
|
||||
|
@ -68,14 +80,10 @@ func (c *CertType) UnmarshalJSON(data []byte) error {
|
|||
if err := json.Unmarshal(data, &s); err != nil {
|
||||
return errors.Wrap(err, "error unmarshaling certificate type")
|
||||
}
|
||||
switch strings.ToLower(s) {
|
||||
case userString:
|
||||
*c = UserCert
|
||||
return nil
|
||||
case hostString:
|
||||
*c = HostCert
|
||||
return nil
|
||||
default:
|
||||
certType, err := CertTypeFromString(s)
|
||||
if err != nil {
|
||||
return errors.Errorf("error unmarshaling '%s' as a certificate type", s)
|
||||
}
|
||||
*c = certType
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,6 +5,37 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func TestCertTypeFromString(t *testing.T) {
|
||||
type args struct {
|
||||
s string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want CertType
|
||||
wantErr bool
|
||||
}{
|
||||
{"user", args{"user"}, UserCert, false},
|
||||
{"USER", args{"USER"}, UserCert, false},
|
||||
{"host", args{"host"}, HostCert, false},
|
||||
{"Host", args{"Host"}, HostCert, false},
|
||||
{" user ", args{" user "}, 0, true},
|
||||
{"invalid", args{"invalid"}, 0, true},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CertTypeFromString(tt.args.s)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CertTypeFromString() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("CertTypeFromString() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertType_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
Loading…
Reference in a new issue