make Duration public

This commit is contained in:
max furman 2019-01-20 21:33:14 -08:00
parent 0615f7eb11
commit 6dc89f46d8
4 changed files with 33 additions and 37 deletions

View file

@ -29,9 +29,9 @@ var (
maxTLSDur = 24 * time.Hour maxTLSDur = 24 * time.Hour
defaultTLSDur = 24 * time.Hour defaultTLSDur = 24 * time.Hour
globalProvisionerClaims = ProvisionerClaims{ globalProvisionerClaims = ProvisionerClaims{
MinTLSDur: (*duration)(&minTLSDur), MinTLSDur: &Duration{5 * time.Minute},
MaxTLSDur: (*duration)(&maxTLSDur), MaxTLSDur: &Duration{24 * time.Hour},
DefaultTLSDur: (*duration)(&defaultTLSDur), DefaultTLSDur: &Duration{24 * time.Hour},
DisableRenewal: &defaultDisableRenewal, DisableRenewal: &defaultDisableRenewal,
} }
) )

View file

@ -12,9 +12,9 @@ import (
// ProvisionerClaims so that individual provisioners can override global claims. // ProvisionerClaims so that individual provisioners can override global claims.
type ProvisionerClaims struct { type ProvisionerClaims struct {
globalClaims *ProvisionerClaims globalClaims *ProvisionerClaims
MinTLSDur *duration `json:"minTLSCertDuration,omitempty"` MinTLSDur *Duration `json:"minTLSCertDuration,omitempty"`
MaxTLSDur *duration `json:"maxTLSCertDuration,omitempty"` MaxTLSDur *Duration `json:"maxTLSCertDuration,omitempty"`
DefaultTLSDur *duration `json:"defaultTLSCertDuration,omitempty"` DefaultTLSDur *Duration `json:"defaultTLSCertDuration,omitempty"`
DisableRenewal *bool `json:"disableRenewal,omitempty"` DisableRenewal *bool `json:"disableRenewal,omitempty"`
} }
@ -32,30 +32,30 @@ func (pc *ProvisionerClaims) Init(global *ProvisionerClaims) (*ProvisionerClaims
// provisioner. If the default is not set within the provisioner, then the global // provisioner. If the default is not set within the provisioner, then the global
// default from the authority configuration will be used. // default from the authority configuration will be used.
func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration { func (pc *ProvisionerClaims) DefaultTLSCertDuration() time.Duration {
if pc.DefaultTLSDur == nil || *pc.DefaultTLSDur == 0 { if pc.DefaultTLSDur == nil || pc.DefaultTLSDur.Duration == 0 {
return pc.globalClaims.DefaultTLSCertDuration() return pc.globalClaims.DefaultTLSCertDuration()
} }
return time.Duration(*pc.DefaultTLSDur) return pc.DefaultTLSDur.Duration
} }
// MinTLSCertDuration returns the minimum TLS cert duration for the provisioner. // MinTLSCertDuration returns the minimum TLS cert duration for the provisioner.
// If the minimum is not set within the provisioner, then the global // If the minimum is not set within the provisioner, then the global
// minimum from the authority configuration will be used. // minimum from the authority configuration will be used.
func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration { func (pc *ProvisionerClaims) MinTLSCertDuration() time.Duration {
if pc.MinTLSDur == nil || *pc.MinTLSDur == 0 { if pc.MinTLSDur == nil || pc.MinTLSDur.Duration == 0 {
return pc.globalClaims.MinTLSCertDuration() return pc.globalClaims.MinTLSCertDuration()
} }
return time.Duration(*pc.MinTLSDur) return pc.MinTLSDur.Duration
} }
// MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner. // MaxTLSCertDuration returns the maximum TLS cert duration for the provisioner.
// If the maximum is not set within the provisioner, then the global // If the maximum is not set within the provisioner, then the global
// maximum from the authority configuration will be used. // maximum from the authority configuration will be used.
func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration { func (pc *ProvisionerClaims) MaxTLSCertDuration() time.Duration {
if pc.MaxTLSDur == nil || *pc.MaxTLSDur == 0 { if pc.MaxTLSDur == nil || pc.MaxTLSDur.Duration == 0 {
return pc.globalClaims.MaxTLSCertDuration() return pc.globalClaims.MaxTLSCertDuration()
} }
return time.Duration(*pc.MaxTLSDur) return pc.MaxTLSDur.Duration
} }
// IsDisableRenewal returns if the renewal flow is disabled for the // IsDisableRenewal returns if the renewal flow is disabled for the

View file

@ -8,15 +8,17 @@ import (
) )
// Duration is a wrapper around Time.Duration to aid with marshal/unmarshal. // Duration is a wrapper around Time.Duration to aid with marshal/unmarshal.
type duration time.Duration type Duration struct {
time.Duration
}
// MarshalJSON parses a duration string and sets it to the duration. // MarshalJSON parses a duration string and sets it to the duration.
// //
// A duration string is a possibly signed sequence of decimal numbers, each with // A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) MarshalJSON() ([]byte, error) { func (d *Duration) MarshalJSON() ([]byte, error) {
return json.Marshal((*time.Duration)(d).String()) return json.Marshal(d.Duration.String())
} }
// UnmarshalJSON parses a duration string and sets it to the duration. // UnmarshalJSON parses a duration string and sets it to the duration.
@ -24,7 +26,7 @@ func (d *duration) MarshalJSON() ([]byte, error) {
// A duration string is a possibly signed sequence of decimal numbers, each with // A duration string is a possibly signed sequence of decimal numbers, each with
// optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". // optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m".
// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h".
func (d *duration) UnmarshalJSON(data []byte) (err error) { func (d *Duration) UnmarshalJSON(data []byte) (err error) {
var ( var (
s string s string
_d time.Duration _d time.Duration
@ -38,7 +40,7 @@ func (d *duration) UnmarshalJSON(data []byte) (err error) {
if _d, err = time.ParseDuration(s); err != nil { if _d, err = time.ParseDuration(s); err != nil {
return errors.Wrapf(err, "error parsing %s as duration", s) return errors.Wrapf(err, "error parsing %s as duration", s)
} }
*d = duration(_d) d.Duration = _d
return return
} }

View file

@ -102,38 +102,32 @@ func Test_multiString_UnmarshalJSON(t *testing.T) {
} }
} }
func durPtr(_d time.Duration) *duration { func TestDuration_UnmarshalJSON(t *testing.T) {
d := new(duration)
*d = duration(_d)
return d
}
func Test_duration_UnmarshalJSON(t *testing.T) {
type args struct { type args struct {
data []byte data []byte
} }
tests := []struct { tests := []struct {
name string name string
d *duration d *Duration
args args args args
want *duration want *Duration
wantErr bool wantErr bool
}{ }{
{"empty", new(duration), args{[]byte{}}, new(duration), true}, {"empty", new(Duration), args{[]byte{}}, new(Duration), true},
{"bad type", new(duration), args{[]byte(`15`)}, new(duration), true}, {"bad type", new(Duration), args{[]byte(`15`)}, new(Duration), true},
{"empty string", new(duration), args{[]byte(`""`)}, new(duration), true}, {"empty string", new(Duration), args{[]byte(`""`)}, new(Duration), true},
{"non duration", new(duration), args{[]byte(`"15"`)}, new(duration), true}, {"non duration", new(Duration), args{[]byte(`"15"`)}, new(Duration), true},
{"duration", new(duration), args{[]byte(`"15m30s"`)}, durPtr(15*time.Minute + 30*time.Second), false}, {"duration", new(Duration), args{[]byte(`"15m30s"`)}, &Duration{15*time.Minute + 30*time.Second}, false},
{"nil", nil, args{nil}, nil, true}, {"nil", nil, args{nil}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { if err := tt.d.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr {
t.Errorf("multiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Duration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(tt.d, tt.want) { if !reflect.DeepEqual(tt.d, tt.want) {
t.Errorf("multiString.UnmarshalJSON() = %v, want %v", tt.d, tt.want) t.Errorf("Duration.UnmarshalJSON() = %v, want %v", tt.d, tt.want)
} }
}) })
} }
@ -142,21 +136,21 @@ func Test_duration_UnmarshalJSON(t *testing.T) {
func Test_duration_MarshalJSON(t *testing.T) { func Test_duration_MarshalJSON(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
d *duration d *Duration
want []byte want []byte
wantErr bool wantErr bool
}{ }{
{"string", durPtr(15*time.Minute + 30*time.Second), []byte(`"15m30s"`), false}, {"string", &Duration{15*time.Minute + 30*time.Second}, []byte(`"15m30s"`), false},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tt.d.MarshalJSON() got, err := tt.d.MarshalJSON()
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Duration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("duration.MarshalJSON() = %v, want %v", got, tt.want) t.Errorf("Duration.MarshalJSON() = %v, want %v", got, tt.want)
} }
}) })
} }