diff --git a/api/api.go b/api/api.go index a92b7902..3f5b54ab 100644 --- a/api/api.go +++ b/api/api.go @@ -36,6 +36,20 @@ type Authority interface { GetFederation() ([]*x509.Certificate, error) } +// TimeDuration is an alias of provisioner.TimeDuration +type TimeDuration = provisioner.TimeDuration + +// NewTimeDuration returns a TimeDuration with the defined time. +func NewTimeDuration(t time.Time) TimeDuration { + return provisioner.NewTimeDuration(t) +} + +// ParseTimeDuration returns a new TimeDuration parsing the RFC 3339 time or +// time.Duration string. +func ParseTimeDuration(s string) (TimeDuration, error) { + return provisioner.ParseTimeDuration(s) +} + // Certificate wraps a *x509.Certificate and adds the json.Marshaler interface. type Certificate struct { *x509.Certificate @@ -154,8 +168,8 @@ type RootResponse struct { type SignRequest struct { CsrPEM CertificateRequest `json:"csr"` OTT string `json:"ott"` - NotAfter time.Time `json:"notAfter"` - NotBefore time.Time `json:"notBefore"` + NotAfter TimeDuration `json:"notAfter"` + NotBefore TimeDuration `json:"notBefore"` } // ProvisionersResponse is the response object that returns the list of diff --git a/api/api_test.go b/api/api_test.go index 80879ef5..e78b370e 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -397,8 +397,8 @@ func TestSignRequest_Validate(t *testing.T) { s := &SignRequest{ CsrPEM: tt.fields.CsrPEM, OTT: tt.fields.OTT, - NotAfter: tt.fields.NotAfter, - NotBefore: tt.fields.NotBefore, + NotAfter: NewTimeDuration(tt.fields.NotAfter), + NotBefore: NewTimeDuration(tt.fields.NotBefore), } if err := s.Validate(); (err != nil) != tt.wantErr { t.Errorf("SignRequest.Validate() error = %v, wantErr %v", err, tt.wantErr) diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index c28fd80b..b8b4e51a 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -14,8 +14,8 @@ import ( // Options contains the options that can be passed to the Sign method. type Options struct { - NotAfter time.Time `json:"notAfter"` - NotBefore time.Time `json:"notBefore"` + NotAfter TimeDuration `json:"notAfter"` + NotBefore TimeDuration `json:"notBefore"` } // SignOption is the interface used to collect all extra options used in the @@ -55,7 +55,7 @@ func (v profileWithOption) Option(Options) x509util.WithOption { type profileDefaultDuration time.Duration func (v profileDefaultDuration) Option(so Options) x509util.WithOption { - return x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, time.Duration(v)) + return x509util.WithNotBeforeAfterDuration(so.NotBefore.Time(), so.NotAfter.Time(), time.Duration(v)) } // emailOnlyIdentity is a CertificateRequestValidator that checks that the only @@ -228,6 +228,6 @@ func createProvisionerExtension(typ int, name, credentialID string) (pkix.Extens } func init() { - // Avoid deadcode warning in profileWithOption + // Avoid dead-code warning in profileWithOption _ = profileWithOption(nil) } diff --git a/authority/provisioner/timeduration.go b/authority/provisioner/timeduration.go new file mode 100644 index 00000000..e3152808 --- /dev/null +++ b/authority/provisioner/timeduration.go @@ -0,0 +1,124 @@ +package provisioner + +import ( + "encoding/json" + "time" + + "github.com/pkg/errors" +) + +var now = func() time.Time { + return time.Now().UTC() +} + +// TimeDuration is a type that represents a time but the JSON unmarshaling can +// use a time using the RFC 3339 format or a time.Duration string. If a duration +// is used, the time will be set on the first call to TimeDuration.Time. +type TimeDuration struct { + t time.Time + d time.Duration +} + +// NewTimeDuration returns a TimeDuration with the defined time. +func NewTimeDuration(t time.Time) TimeDuration { + return TimeDuration{t: t} +} + +// ParseTimeDuration returns a new TimeDuration parsing the RFC 3339 time or +// time.Duration string. +func ParseTimeDuration(s string) (TimeDuration, error) { + if s == "" { + return TimeDuration{}, nil + } + + // Try to use the unquoted RFC 3339 format + var t time.Time + if err := t.UnmarshalText([]byte(s)); err == nil { + return TimeDuration{t: t.UTC()}, nil + } + + // Try to use the time.Duration string format + if d, err := time.ParseDuration(s); err == nil { + return TimeDuration{d: d}, nil + } + + return TimeDuration{}, errors.Errorf("failed to parse %s", s) +} + +// SetDuration initializes the TimeDuration with the given duration string. If +// the time was set it will re-set to zero. +func (t *TimeDuration) SetDuration(d time.Duration) { + t.t, t.d = time.Time{}, d +} + +// SetTime initializes the TimeDuration with the given time. If the duration is +// set it will be re-set to zero. +func (t *TimeDuration) SetTime(tt time.Time) { + t.t, t.d = tt, 0 +} + +// MarshalJSON implements the json.Marshaler interface. If the time is set it +// will return the time in RFC 3339 format if not it will return the duration +// string. +func (t TimeDuration) MarshalJSON() ([]byte, error) { + switch { + case t.t.IsZero(): + if t.d == 0 { + return []byte("null"), nil + } + return json.Marshal(t.d.String()) + default: + return t.t.MarshalJSON() + } +} + +// UnmarshalJSON implements the json.Unmarshaler interface. The time is expected +// to be a quoted string in RFC 3339 format or a quoted time.Duration string. +func (t *TimeDuration) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return errors.Wrapf(err, "error unmarshaling %s", data) + } + + // Empty TimeDuration + if s == "" { + *t = TimeDuration{} + return nil + } + + // Try to use the unquoted RFC 3339 format + var tt time.Time + if err := tt.UnmarshalText([]byte(s)); err == nil { + *t = TimeDuration{t: tt} + return nil + } + + // Try to use the time.Duration string format + if d, err := time.ParseDuration(s); err == nil { + *t = TimeDuration{d: d} + return nil + } + + return errors.Errorf("failed to parse %s", data) +} + +// Time calculates the embedded time.Time, sets it if necessary, and returns it. +func (t *TimeDuration) Time() time.Time { + switch { + case t == nil: + return time.Time{} + case t.t.IsZero(): + if t.d == 0 { + return time.Time{} + } + t.t = now().Add(t.d) + return t.t + default: + return t.t.UTC() + } +} + +// String implements the fmt.Stringer interface. +func (t *TimeDuration) String() string { + return t.Time().String() +} diff --git a/authority/provisioner/timeduration_test.go b/authority/provisioner/timeduration_test.go new file mode 100644 index 00000000..97dd4ce5 --- /dev/null +++ b/authority/provisioner/timeduration_test.go @@ -0,0 +1,251 @@ +package provisioner + +import ( + "reflect" + "testing" + "time" +) + +func TestNewTimeDuration(t *testing.T) { + tm := time.Unix(1584198566, 535897000).UTC() + type args struct { + t time.Time + } + tests := []struct { + name string + args args + want TimeDuration + }{ + {"ok", args{tm}, TimeDuration{t: tm}}, + {"zero", args{time.Time{}}, TimeDuration{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewTimeDuration(tt.args.t); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewTimeDuration() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseTimeDuration(t *testing.T) { + type args struct { + s string + } + tests := []struct { + name string + args args + want TimeDuration + wantErr bool + }{ + {"timestamp", args{"2020-03-14T15:09:26.535897Z"}, TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false}, + {"timestamp", args{"2020-03-14T15:09:26Z"}, TimeDuration{t: time.Unix(1584198566, 0).UTC()}, false}, + {"timestamp", args{"2020-03-14T15:09:26.535897-07:00"}, TimeDuration{t: time.Unix(1584223766, 535897000).UTC()}, false}, + {"timestamp", args{"2020-03-14T15:09:26-07:00"}, TimeDuration{t: time.Unix(1584223766, 0).UTC()}, false}, + {"timestamp", args{"2020-03-14T15:09:26.535897+07:00"}, TimeDuration{t: time.Unix(1584173366, 535897000).UTC()}, false}, + {"timestamp", args{"2020-03-14T15:09:26+07:00"}, TimeDuration{t: time.Unix(1584173366, 0).UTC()}, false}, + {"1h", args{"1h"}, TimeDuration{d: 1 * time.Hour}, false}, + {"-24h60m60s", args{"-24h60m60s"}, TimeDuration{d: -24*time.Hour - 60*time.Minute - 60*time.Second}, false}, + {"0", args{"0"}, TimeDuration{}, false}, + {"empty", args{""}, TimeDuration{}, false}, + {"fail", args{"2020-03-14T15:09:26Z07:00"}, TimeDuration{}, true}, + {"fail", args{"1d"}, TimeDuration{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseTimeDuration(tt.args.s) + if (err != nil) != tt.wantErr { + t.Errorf("ParseTimeDuration() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseTimeDuration() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTimeDuration_SetDuration(t *testing.T) { + type fields struct { + t time.Time + d time.Duration + } + type args struct { + d time.Duration + } + tests := []struct { + name string + fields fields + args args + want *TimeDuration + }{ + {"new", fields{}, args{2 * time.Hour}, &TimeDuration{d: 2 * time.Hour}}, + {"old", fields{time.Now(), 1 * time.Hour}, args{2 * time.Hour}, &TimeDuration{d: 2 * time.Hour}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + td := &TimeDuration{ + t: tt.fields.t, + d: tt.fields.d, + } + td.SetDuration(tt.args.d) + if !reflect.DeepEqual(td, tt.want) { + t.Errorf("SetDuration() = %v, want %v", td, tt.want) + } + }) + } +} + +func TestTimeDuration_SetTime(t *testing.T) { + tm := time.Unix(1584198566, 535897000).UTC() + + type fields struct { + t time.Time + d time.Duration + } + type args struct { + tt time.Time + } + tests := []struct { + name string + fields fields + args args + want *TimeDuration + }{ + {"new", fields{}, args{tm}, &TimeDuration{t: tm}}, + {"old", fields{time.Now(), 1 * time.Hour}, args{tm}, &TimeDuration{t: tm}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + td := &TimeDuration{ + t: tt.fields.t, + d: tt.fields.d, + } + td.SetTime(tt.args.tt) + if !reflect.DeepEqual(td, tt.want) { + t.Errorf("SetTime() = %v, want %v", td, tt.want) + } + }) + } +} + +func TestTimeDuration_MarshalJSON(t *testing.T) { + tm := time.Unix(1584198566, 535897000).UTC() + tests := []struct { + name string + timeDuration TimeDuration + want []byte + wantErr bool + }{ + {"null", TimeDuration{}, []byte("null"), false}, + {"timestamp", TimeDuration{t: tm}, []byte(`"2020-03-14T15:09:26.535897Z"`), false}, + {"duration", TimeDuration{d: 1 * time.Hour}, []byte(`"1h0m0s"`), false}, + {"fail", TimeDuration{t: time.Date(-1, 0, 0, 0, 0, 0, 0, time.UTC)}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.timeDuration.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("TimeDuration.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("TimeDuration.MarshalJSON() = %s, want %s", got, tt.want) + } + }) + } +} + +func TestTimeDuration_UnmarshalJSON(t *testing.T) { + type args struct { + data []byte + } + tests := []struct { + name string + args args + want *TimeDuration + wantErr bool + }{ + {"null", args{[]byte("null")}, &TimeDuration{}, false}, + {"timestamp", args{[]byte(`"2020-03-14T15:09:26.535897Z"`)}, &TimeDuration{t: time.Unix(1584198566, 535897000).UTC()}, false}, + {"duration", args{[]byte(`"1h"`)}, &TimeDuration{d: time.Hour}, false}, + {"fail", args{[]byte("123")}, &TimeDuration{}, true}, + {"fail", args{[]byte(`"2020-03-14T15:09:26.535897Z07:00"`)}, &TimeDuration{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + td := &TimeDuration{} + if err := td.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("TimeDuration.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(td, tt.want) { + t.Errorf("TimeDuration.UnmarshalJSON() = %s, want %s", td, tt.want) + } + }) + } +} + +func TestTimeDuration_Time(t *testing.T) { + nowFn := now + defer func() { + now = nowFn + now() + }() + tm := time.Unix(1584198566, 535897000).UTC() + now = func() time.Time { + return tm + } + tests := []struct { + name string + timeDuration *TimeDuration + want time.Time + }{ + {"zero", nil, time.Time{}}, + {"zero", &TimeDuration{}, time.Time{}}, + {"timestamp", &TimeDuration{t: tm}, tm}, + {"local", &TimeDuration{t: tm.Local()}, tm}, + {"duration", &TimeDuration{d: 1 * time.Hour}, tm.Add(1 * time.Hour)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.timeDuration.Time() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("TimeDuration.Time() = %v, want %v", got, tt.want) + + } + }) + } +} + +func TestTimeDuration_String(t *testing.T) { + nowFn := now + defer func() { + now = nowFn + now() + }() + tm := time.Unix(1584198566, 535897000).UTC() + now = func() time.Time { + return tm + } + type fields struct { + t time.Time + d time.Duration + } + tests := []struct { + name string + timeDuration *TimeDuration + want string + }{ + {"zero", nil, "0001-01-01 00:00:00 +0000 UTC"}, + {"zero", &TimeDuration{}, "0001-01-01 00:00:00 +0000 UTC"}, + {"timestamp", &TimeDuration{t: tm}, "2020-03-14 15:09:26.535897 +0000 UTC"}, + {"duration", &TimeDuration{d: 1 * time.Hour}, "2020-03-14 16:09:26.535897 +0000 UTC"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.timeDuration.String(); got != tt.want { + t.Errorf("TimeDuration.String() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/authority/tls_test.go b/authority/tls_test.go index 47ac7966..eb1793e2 100644 --- a/authority/tls_test.go +++ b/authority/tls_test.go @@ -89,8 +89,8 @@ func TestSign(t *testing.T) { nb := time.Now() signOpts := provisioner.Options{ - NotBefore: nb, - NotAfter: nb.Add(time.Minute * 5), + NotBefore: provisioner.NewTimeDuration(nb), + NotAfter: provisioner.NewTimeDuration(nb.Add(time.Minute * 5)), } // Create a token to get test extra opts. @@ -171,8 +171,8 @@ func TestSign(t *testing.T) { "fail provisioner duration claim": func(t *testing.T) *signTest { csr := getCSR(t, priv) _signOpts := provisioner.Options{ - NotBefore: nb, - NotAfter: nb.Add(time.Hour * 25), + NotBefore: provisioner.NewTimeDuration(nb), + NotAfter: provisioner.NewTimeDuration(nb.Add(time.Hour * 25)), } return &signTest{ auth: a, @@ -229,8 +229,8 @@ func TestSign(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, leaf.NotBefore, signOpts.NotBefore.UTC().Truncate(time.Second)) - assert.Equals(t, leaf.NotAfter, signOpts.NotAfter.UTC().Truncate(time.Second)) + assert.Equals(t, leaf.NotBefore, signOpts.NotBefore.Time().Truncate(time.Second)) + assert.Equals(t, leaf.NotAfter, signOpts.NotAfter.Time().Truncate(time.Second)) tmplt := a.config.AuthorityConfig.Template assert.Equals(t, fmt.Sprintf("%v", leaf.Subject), fmt.Sprintf("%v", &pkix.Name{ @@ -300,13 +300,13 @@ func TestRenew(t *testing.T) { nb1 := now.Add(-time.Minute * 7) na1 := now so := &provisioner.Options{ - NotBefore: nb1, - NotAfter: na1, + NotBefore: provisioner.NewTimeDuration(nb1), + NotAfter: provisioner.NewTimeDuration(na1), } leaf, err := x509util.NewLeafProfile("renew", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0), + x509util.WithNotBeforeAfterDuration(so.NotBefore.Time(), so.NotAfter.Time(), 0), withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"), withProvisionerOID("Max", a.config.AuthorityConfig.Provisioners[0].(*provisioner.JWK).Key.KeyID)) @@ -318,7 +318,7 @@ func TestRenew(t *testing.T) { leafNoRenew, err := x509util.NewLeafProfile("norenew", a.intermediateIdentity.Crt, a.intermediateIdentity.Key, - x509util.WithNotBeforeAfterDuration(so.NotBefore, so.NotAfter, 0), + x509util.WithNotBeforeAfterDuration(so.NotBefore.Time(), so.NotAfter.Time(), 0), withDefaultASN1DN(a.config.AuthorityConfig.Template), x509util.WithPublicKey(pub), x509util.WithHosts("test.smallstep.com,test"), withProvisionerOID("dev", a.config.AuthorityConfig.Provisioners[2].(*provisioner.JWK).Key.KeyID), diff --git a/ca/ca_test.go b/ca/ca_test.go index d5fc17f7..cbbd6d48 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -209,8 +209,8 @@ ZEp7knvU2psWRw== body, err := json.Marshal(&api.SignRequest{ CsrPEM: api.CertificateRequest{CertificateRequest: csr}, OTT: raw, - NotBefore: now, - NotAfter: leafExpiry, + NotBefore: api.NewTimeDuration(now), + NotAfter: api.NewTimeDuration(leafExpiry), }) assert.FatalError(t, err) return &signTest{ @@ -242,8 +242,8 @@ ZEp7knvU2psWRw== body, err := json.Marshal(&api.SignRequest{ CsrPEM: api.CertificateRequest{CertificateRequest: csr}, OTT: raw, - NotBefore: now, - NotAfter: leafExpiry, + NotBefore: api.NewTimeDuration(now), + NotAfter: api.NewTimeDuration(leafExpiry), }) assert.FatalError(t, err) return &signTest{ diff --git a/ca/client_test.go b/ca/client_test.go index 68fefd09..bfac97a4 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -257,8 +257,8 @@ func TestClient_Sign(t *testing.T) { request := &api.SignRequest{ CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, OTT: "the-ott", - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(0, 1, 0), + NotBefore: api.NewTimeDuration(time.Now()), + NotAfter: api.NewTimeDuration(time.Now().AddDate(0, 1, 0)), } unauthorized := api.Unauthorized(fmt.Errorf("Unauthorized")) badRequest := api.BadRequest(fmt.Errorf("Bad Request")) diff --git a/ca/tls_test.go b/ca/tls_test.go index c71e839d..b88e825a 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -95,8 +95,8 @@ func signDuration(srv *httptest.Server, domain string, duration time.Duration) ( } if duration > 0 { - req.NotBefore = time.Now() - req.NotAfter = req.NotBefore.Add(duration) + req.NotBefore = api.NewTimeDuration(time.Now()) + req.NotAfter = api.NewTimeDuration(req.NotBefore.Time().Add(duration)) } client, err := NewClient(srv.URL, WithRootFile("testdata/secrets/root_ca.crt"))