package provisioner

import (
	"reflect"
	"testing"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/assert"
	"go.step.sm/crypto/keyutil"
	"golang.org/x/crypto/ssh"
)

func TestSSHOptions_Type(t *testing.T) {
	type fields struct {
		CertType string
	}
	tests := []struct {
		name   string
		fields fields
		want   uint32
	}{
		{"user", fields{"user"}, 1},
		{"host", fields{"host"}, 2},
		{"empty", fields{""}, 0},
		{"invalid", fields{"invalid"}, 0},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			o := SignSSHOptions{
				CertType: tt.fields.CertType,
			}
			if got := o.Type(); got != tt.want {
				t.Errorf("SSHOptions.Type() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestSSHOptions_Modify exercises SignSSHOptions.Modify, which applies the
// options (type, key id, principals, validity window) to a certificate.
// Failing cases assert that the returned error message starts with the
// expected text; passing cases run a validation callback on the modified
// certificate.
func TestSSHOptions_Modify(t *testing.T) {
	// valid receives the subtest's *testing.T so that assertion failures are
	// attributed to the subtest that produced them rather than the parent.
	type test struct {
		so    SignSSHOptions
		cert  *ssh.Certificate
		valid func(*testing.T, *ssh.Certificate)
		err   error
	}
	tests := map[string](func() test){
		"fail/unexpected-cert-type": func() test {
			return test{
				so:   SignSSHOptions{CertType: "foo"},
				cert: new(ssh.Certificate),
				err:  errors.Errorf("ssh certificate has an unknown type - foo"),
			}
		},
		"fail/validAfter-greater-validBefore": func() test {
			return test{
				so:   SignSSHOptions{CertType: "user"},
				cert: &ssh.Certificate{ValidAfter: uint64(15), ValidBefore: uint64(10)},
				err:  errors.Errorf("ssh certificate valid after cannot be greater than valid before"),
			}
		},
		"ok/user-cert": func() test {
			return test{
				so:   SignSSHOptions{CertType: "user"},
				cert: new(ssh.Certificate),
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
				},
			}
		},
		"ok/host-cert": func() test {
			return test{
				so:   SignSSHOptions{CertType: "host"},
				cert: new(ssh.Certificate),
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
				},
			}
		},
		"ok": func() test {
			va := time.Now().Add(5 * time.Minute)
			vb := time.Now().Add(1 * time.Hour)
			so := SignSSHOptions{CertType: "host", KeyID: "foo", Principals: []string{"foo", "bar"},
				ValidAfter: NewTimeDuration(va), ValidBefore: NewTimeDuration(vb)}
			return test{
				so:   so,
				cert: new(ssh.Certificate),
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
					assert.Equals(t, cert.KeyId, so.KeyID)
					assert.Equals(t, cert.ValidPrincipals, so.Principals)
					assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
					assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run()
			if err := tc.so.Modify(tc.cert, tc.so); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
				return
			}
			if assert.Nil(t, tc.err) {
				tc.valid(t, tc.cert)
			}
		})
	}
}

// TestSSHOptions_Match exercises SignSSHOptions.match, which compares the
// receiver against another set of options. Mismatched type, principals,
// ValidAfter or ValidBefore must produce an error whose message starts with
// the expected text. An empty receiver or an empty comparison value, as well
// as identical options, must match.
func TestSSHOptions_Match(t *testing.T) {
	type test struct {
		so  SignSSHOptions
		cmp SignSSHOptions
		err error
	}
	tests := map[string](func() test){
		"fail/cert-type": func() test {
			return test{
				so:  SignSSHOptions{CertType: "foo"},
				cmp: SignSSHOptions{CertType: "bar"},
				err: errors.Errorf("ssh certificate type does not match - got bar, want foo"),
			}
		},
		"fail/principals": func() test {
			return test{
				so:  SignSSHOptions{Principals: []string{"foo"}},
				cmp: SignSSHOptions{Principals: []string{"bar"}},
				err: errors.Errorf("ssh certificate principals does not match - got [bar], want [foo]"),
			}
		},
		"fail/validAfter": func() test {
			return test{
				so:  SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(1 * time.Minute))},
				cmp: SignSSHOptions{ValidAfter: NewTimeDuration(time.Now().Add(5 * time.Minute))},
				err: errors.Errorf("ssh certificate valid after does not match"),
			}
		},
		"fail/validBefore": func() test {
			return test{
				so:  SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(1 * time.Minute))},
				cmp: SignSSHOptions{ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute))},
				err: errors.Errorf("ssh certificate valid before does not match"),
			}
		},
		"ok/original-empty": func() test {
			return test{
				so: SignSSHOptions{},
				cmp: SignSSHOptions{
					CertType:    "foo",
					Principals:  []string{"foo"},
					ValidAfter:  NewTimeDuration(time.Now().Add(1 * time.Minute)),
					ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)),
				},
			}
		},
		"ok/cmp-empty": func() test {
			return test{
				cmp: SignSSHOptions{},
				so: SignSSHOptions{
					CertType:    "foo",
					Principals:  []string{"foo"},
					ValidAfter:  NewTimeDuration(time.Now().Add(1 * time.Minute)),
					ValidBefore: NewTimeDuration(time.Now().Add(5 * time.Minute)),
				},
			}
		},
		"ok/equal": func() test {
			// Share the TimeDuration values so both sides are exactly equal.
			n := time.Now()
			va := NewTimeDuration(n.Add(1 * time.Minute))
			vb := NewTimeDuration(n.Add(5 * time.Minute))
			return test{
				cmp: SignSSHOptions{
					CertType:    "foo",
					Principals:  []string{"foo"},
					ValidAfter:  va,
					ValidBefore: vb,
				},
				so: SignSSHOptions{
					CertType:    "foo",
					Principals:  []string{"foo"},
					ValidAfter:  va,
					ValidBefore: vb,
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run()
			if err := tc.so.match(tc.cmp); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

func Test_sshCertPrincipalsModifier_Modify(t *testing.T) {
	type test struct {
		modifier sshCertPrincipalsModifier
		cert     *ssh.Certificate
		expected []string
	}
	tests := map[string](func() test){
		"ok": func() test {
			a := []string{"foo", "bar"}
			return test{
				modifier: sshCertPrincipalsModifier(a),
				cert:     new(ssh.Certificate),
				expected: a,
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run()
			if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
				assert.Equals(t, tc.cert.ValidPrincipals, tc.expected)
			}
		})
	}
}

// Test_sshCertKeyIDModifier_Modify verifies that applying an
// sshCertKeyIDModifier copies the modifier's value into the certificate's
// KeyId field.
func Test_sshCertKeyIDModifier_Modify(t *testing.T) {
	const keyID = "foo"
	cases := []struct {
		name     string
		modifier sshCertKeyIDModifier
		cert     *ssh.Certificate
		want     string
	}{
		{
			name:     "ok",
			modifier: sshCertKeyIDModifier(keyID),
			cert:     new(ssh.Certificate),
			want:     keyID,
		},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			err := c.modifier.Modify(c.cert, SignSSHOptions{})
			if !assert.Nil(t, err) {
				return
			}
			assert.Equals(t, c.cert.KeyId, c.want)
		})
	}
}

func Test_sshCertTypeModifier_Modify(t *testing.T) {
	type test struct {
		modifier sshCertTypeModifier
		cert     *ssh.Certificate
		expected uint32
	}
	tests := map[string](func() test){
		"ok/user": func() test {
			return test{
				modifier: sshCertTypeModifier("user"),
				cert:     new(ssh.Certificate),
				expected: ssh.UserCert,
			}
		},
		"ok/host": func() test {
			return test{
				modifier: sshCertTypeModifier("host"),
				cert:     new(ssh.Certificate),
				expected: ssh.HostCert,
			}
		},
		"ok/default": func() test {
			return test{
				modifier: sshCertTypeModifier("foo"),
				cert:     new(ssh.Certificate),
				expected: 0,
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run()
			if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
				assert.Equals(t, tc.cert.CertType, uint32(tc.expected))
			}
		})
	}
}

func Test_sshCertValidAfterModifier_Modify(t *testing.T) {
	type test struct {
		modifier sshCertValidAfterModifier
		cert     *ssh.Certificate
		expected uint64
	}
	tests := map[string](func() test){
		"ok": func() test {
			return test{
				modifier: sshCertValidAfterModifier(15),
				cert:     new(ssh.Certificate),
				expected: 15,
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run()
			if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
				assert.Equals(t, tc.cert.ValidAfter, tc.expected)
			}
		})
	}
}

// Test_sshCertDefaultsModifier_Modify checks sshCertDefaultsModifier. On an
// empty certificate it fills principals, type and validity window from the
// wrapped options. On a certificate whose fields are already set, it leaves
// them untouched.
func Test_sshCertDefaultsModifier_Modify(t *testing.T) {
	// valid receives the subtest's *testing.T so that assertion failures are
	// attributed to the subtest that produced them rather than the parent.
	type test struct {
		modifier sshCertDefaultsModifier
		cert     *ssh.Certificate
		valid    func(*testing.T, *ssh.Certificate)
	}
	tests := map[string](func() test){
		"ok/changes": func() test {
			n := time.Now()
			va := NewTimeDuration(n.Add(1 * time.Minute))
			vb := NewTimeDuration(n.Add(5 * time.Minute))
			so := SignSSHOptions{
				Principals:  []string{"foo", "bar"},
				CertType:    "host",
				ValidAfter:  va,
				ValidBefore: vb,
			}
			return test{
				modifier: sshCertDefaultsModifier(so),
				cert:     new(ssh.Certificate),
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.ValidPrincipals, so.Principals)
					assert.Equals(t, cert.CertType, uint32(ssh.HostCert))
					assert.Equals(t, cert.ValidAfter, uint64(so.ValidAfter.RelativeTime(time.Now()).Unix()))
					assert.Equals(t, cert.ValidBefore, uint64(so.ValidBefore.RelativeTime(time.Now()).Unix()))
				},
			}
		},
		"ok/no-changes": func() test {
			n := time.Now()
			so := SignSSHOptions{
				Principals:  []string{"foo", "bar"},
				CertType:    "host",
				ValidAfter:  NewTimeDuration(n.Add(15 * time.Minute)),
				ValidBefore: NewTimeDuration(n.Add(25 * time.Minute)),
			}
			return test{
				modifier: sshCertDefaultsModifier(so),
				cert: &ssh.Certificate{
					CertType:        uint32(ssh.UserCert),
					ValidPrincipals: []string{"zap", "zoop"},
					ValidAfter:      15,
					ValidBefore:     25,
				},
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.ValidPrincipals, []string{"zap", "zoop"})
					assert.Equals(t, cert.CertType, uint32(ssh.UserCert))
					assert.Equals(t, cert.ValidAfter, uint64(15))
					assert.Equals(t, cert.ValidBefore, uint64(25))
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run()
			if assert.Nil(t, tc.modifier.Modify(tc.cert, SignSSHOptions{})) {
				tc.valid(t, tc.cert)
			}
		})
	}
}

// Test_sshDefaultExtensionModifier_Modify checks sshDefaultExtensionModifier:
//   - host certificates end up with no extensions;
//   - user certificates get the default permit-* extensions with empty
//     values, and any pre-existing extensions are kept;
//   - an unknown certificate type is rejected.
func Test_sshDefaultExtensionModifier_Modify(t *testing.T) {
	// Extensions the modifier is expected to add to every user certificate.
	defaultUserExtensions := []string{
		"permit-X11-forwarding",
		"permit-agent-forwarding",
		"permit-port-forwarding",
		"permit-pty",
		"permit-user-rc",
	}
	// assertDefaultUserExtensions fails t unless every default extension is
	// present with an empty value.
	assertDefaultUserExtensions := func(t *testing.T, cert *ssh.Certificate) {
		for _, ext := range defaultUserExtensions {
			val, ok := cert.Extensions[ext]
			if !ok {
				t.Errorf("expected extension %q to be set", ext)
				continue
			}
			assert.Equals(t, val, "")
		}
	}

	// valid receives the subtest's *testing.T so that assertion failures are
	// attributed to the subtest that produced them rather than the parent.
	type test struct {
		modifier sshDefaultExtensionModifier
		cert     *ssh.Certificate
		valid    func(*testing.T, *ssh.Certificate)
		err      error
	}
	tests := map[string](func() test){
		"fail/unexpected-cert-type": func() test {
			cert := &ssh.Certificate{CertType: 3}
			return test{
				modifier: sshDefaultExtensionModifier{},
				cert:     cert,
				err:      errors.New("ssh certificate type has not been set or is invalid"),
			}
		},
		"ok/host": func() test {
			cert := &ssh.Certificate{CertType: ssh.HostCert}
			return test{
				modifier: sshDefaultExtensionModifier{},
				cert:     cert,
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Len(t, 0, cert.Extensions)
				},
			}
		},
		"ok/user/extensions-exists": func() test {
			cert := &ssh.Certificate{CertType: ssh.UserCert, Permissions: ssh.Permissions{Extensions: map[string]string{
				"foo": "bar",
			}}}
			return test{
				modifier: sshDefaultExtensionModifier{},
				cert:     cert,
				valid: func(t *testing.T, cert *ssh.Certificate) {
					// The pre-existing extension must survive.
					val, ok := cert.Extensions["foo"]
					assert.True(t, ok)
					assert.Equals(t, val, "bar")

					assertDefaultUserExtensions(t, cert)
				},
			}
		},
		"ok/user/no-extensions": func() test {
			return test{
				modifier: sshDefaultExtensionModifier{},
				cert:     &ssh.Certificate{CertType: ssh.UserCert},
				valid: func(t *testing.T, cert *ssh.Certificate) {
					// Nothing beyond the defaults should be added.
					_, ok := cert.Extensions["foo"]
					assert.False(t, ok)

					assertDefaultUserExtensions(t, cert)
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run()
			if err := tc.modifier.Modify(tc.cert, SignSSHOptions{}); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
				return
			}
			if assert.Nil(t, tc.err) {
				tc.valid(t, tc.cert)
			}
		})
	}
}

// Test_sshCertDefaultValidator_Valid walks sshCertDefaultValidator through its
// checks one field at a time. Each failing case fills in every field checked
// before the one under test, and asserts the error message prefix for that
// field. The ok cases cover user/host certificates, empty principals, and
// absent extensions.
func Test_sshCertDefaultValidator_Valid(t *testing.T) {
	pub, _, err := keyutil.GenerateDefaultKeyPair()
	assert.FatalError(t, err)
	sshPub, err := ssh.NewPublicKey(pub)
	assert.FatalError(t, err)
	v := sshCertDefaultValidator{}
	tests := []struct {
		name string
		cert *ssh.Certificate
		err  error
	}{
		{
			"fail/zero-nonce",
			&ssh.Certificate{},
			errors.New("ssh certificate nonce cannot be empty"),
		},
		{
			"fail/nil-key",
			&ssh.Certificate{Nonce: []byte("foo")},
			errors.New("ssh certificate key cannot be nil"),
		},
		{
			"fail/zero-serial",
			&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub},
			errors.New("ssh certificate serial cannot be 0"),
		},
		{
			"fail/unexpected-cert-type",
			// UserCert = 1, HostCert = 2
			&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, CertType: 3, Serial: 1},
			errors.New("ssh certificate has an unknown type: 3"),
		},
		{
			"fail/empty-cert-key-id",
			&ssh.Certificate{Nonce: []byte("foo"), Key: sshPub, Serial: 1, CertType: 1},
			errors.New("ssh certificate key id cannot be empty"),
		},
		{
			"fail/zero-validAfter",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        1,
				KeyId:           "foo",
				ValidPrincipals: []string{"foo"},
				ValidAfter:      0,
			},
			errors.New("ssh certificate validAfter cannot be 0"),
		},
		{
			"fail/validBefore-past",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        1,
				KeyId:           "foo",
				ValidPrincipals: []string{"foo"},
				ValidAfter:      uint64(time.Now().Add(-10 * time.Minute).Unix()),
				ValidBefore:     uint64(time.Now().Add(-5 * time.Minute).Unix()),
			},
			errors.New("ssh certificate validBefore cannot be in the past"),
		},
		{
			"fail/validAfter-after-validBefore",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        1,
				KeyId:           "foo",
				ValidPrincipals: []string{"foo"},
				ValidAfter:      uint64(time.Now().Add(15 * time.Minute).Unix()),
				ValidBefore:     uint64(time.Now().Add(10 * time.Minute).Unix()),
			},
			errors.New("ssh certificate validBefore cannot be before validAfter"),
		},
		{
			"fail/nil-signature-key",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        1,
				KeyId:           "foo",
				ValidPrincipals: []string{"foo"},
				ValidAfter:      uint64(time.Now().Unix()),
				ValidBefore:     uint64(time.Now().Add(10 * time.Minute).Unix()),
				Permissions: ssh.Permissions{
					Extensions: map[string]string{"foo": "bar"},
				},
			},
			errors.New("ssh certificate signature key cannot be nil"),
		},
		{
			"fail/nil-signature",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        1,
				KeyId:           "foo",
				ValidPrincipals: []string{"foo"},
				ValidAfter:      uint64(time.Now().Unix()),
				ValidBefore:     uint64(time.Now().Add(10 * time.Minute).Unix()),
				Permissions: ssh.Permissions{
					Extensions: map[string]string{"foo": "bar"},
				},
				SignatureKey: sshPub,
			},
			errors.New("ssh certificate signature cannot be nil"),
		},
		{
			"ok/userCert",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        1,
				KeyId:           "foo",
				ValidPrincipals: []string{"foo"},
				ValidAfter:      uint64(time.Now().Unix()),
				ValidBefore:     uint64(time.Now().Add(10 * time.Minute).Unix()),
				Permissions: ssh.Permissions{
					Extensions: map[string]string{"foo": "bar"},
				},
				SignatureKey: sshPub,
				Signature:    &ssh.Signature{},
			},
			nil,
		},
		{
			"ok/hostCert",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        2,
				KeyId:           "foo",
				ValidPrincipals: []string{"foo"},
				ValidAfter:      uint64(time.Now().Unix()),
				ValidBefore:     uint64(time.Now().Add(10 * time.Minute).Unix()),
				SignatureKey:    sshPub,
				Signature:       &ssh.Signature{},
			},
			nil,
		},
		{
			"ok/emptyPrincipals",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        1,
				KeyId:           "foo",
				ValidPrincipals: []string{},
				ValidAfter:      uint64(time.Now().Unix()),
				ValidBefore:     uint64(time.Now().Add(10 * time.Minute).Unix()),
				SignatureKey:    sshPub,
				Signature:       &ssh.Signature{},
			},
			nil,
		},
		{
			// User certificate with principals but no extensions. This must
			// differ from ok/emptyPrincipals so that it isolates the
			// "no extensions" condition.
			"ok/empty-extensions",
			&ssh.Certificate{
				Nonce:           []byte("foo"),
				Key:             sshPub,
				Serial:          1,
				CertType:        1,
				KeyId:           "foo",
				ValidPrincipals: []string{"foo"},
				ValidAfter:      uint64(time.Now().Unix()),
				ValidBefore:     uint64(time.Now().Add(10 * time.Minute).Unix()),
				SignatureKey:    sshPub,
				Signature:       &ssh.Signature{},
			},
			nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := v.Valid(tt.cert, SignSSHOptions{}); err != nil {
				if assert.NotNil(t, tt.err) {
					assert.HasPrefix(t, err.Error(), tt.err.Error())
				}
			} else {
				assert.Nil(t, tt.err)
			}
		})
	}
}

// Test_sshCertValidityValidator checks sshCertValidityValidator against an
// X5C provisioner's claimer. The cases cover:
//   - missing or inverted validity bounds;
//   - unset or unknown certificate types;
//   - the min/max duration limits, including inputs exactly at each limit.
// The max limit is exercised together with a one-second Backdate.
func Test_sshCertValidityValidator(t *testing.T) {
	p, err := generateX5C(nil)
	assert.FatalError(t, err)
	v := sshCertValidityValidator{p.claimer}
	n := now()
	tests := []struct {
		name string
		cert *ssh.Certificate
		opts SignSSHOptions
		err  error
	}{
		{
			name: "fail/validAfter-0",
			cert: &ssh.Certificate{CertType: ssh.UserCert},
			opts: SignSSHOptions{},
			err:  errors.New("ssh certificate validAfter cannot be 0"),
		},
		{
			name: "fail/validBefore-in-past",
			cert: &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(-time.Minute).Unix())},
			opts: SignSSHOptions{},
			err:  errors.New("ssh certificate validBefore cannot be in the past"),
		},
		{
			name: "fail/validBefore-before-validAfter",
			cert: &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: uint64(now().Add(5 * time.Minute).Unix()), ValidBefore: uint64(now().Add(3 * time.Minute).Unix())},
			opts: SignSSHOptions{},
			err:  errors.New("ssh certificate validBefore cannot be before validAfter"),
		},
		{
			name: "fail/cert-type-not-set",
			cert: &ssh.Certificate{ValidAfter: uint64(now().Unix()), ValidBefore: uint64(now().Add(10 * time.Minute).Unix())},
			opts: SignSSHOptions{},
			err:  errors.New("ssh certificate type has not been set"),
		},
		{
			name: "fail/unexpected-cert-type",
			cert: &ssh.Certificate{
				CertType:    3,
				ValidAfter:  uint64(now().Unix()),
				ValidBefore: uint64(now().Add(10 * time.Minute).Unix()),
			},
			opts: SignSSHOptions{},
			err:  errors.New("unknown ssh certificate type 3"),
		},
		{
			name: "fail/duration<min",
			cert: &ssh.Certificate{
				CertType:    1,
				ValidAfter:  uint64(n.Unix()),
				ValidBefore: uint64(n.Add(4 * time.Minute).Unix()),
			},
			opts: SignSSHOptions{Backdate: time.Second},
			err:  errors.New("requested duration of 4m0s is less than minimum accepted duration for selected provisioner of 5m0s"),
		},
		{
			name: "ok/duration-exactly-min",
			cert: &ssh.Certificate{
				CertType:    1,
				ValidAfter:  uint64(n.Unix()),
				ValidBefore: uint64(n.Add(5 * time.Minute).Unix()),
			},
			opts: SignSSHOptions{Backdate: time.Second},
			err:  nil,
		},
		{
			name: "fail/duration>max",
			cert: &ssh.Certificate{
				CertType:    1,
				ValidAfter:  uint64(n.Unix()),
				ValidBefore: uint64(n.Add(48 * time.Hour).Unix()),
			},
			opts: SignSSHOptions{Backdate: time.Second},
			err:  errors.New("requested duration of 48h0m0s is greater than maximum accepted duration for selected provisioner of 24h0m1s"),
		},
		{
			name: "ok/duration-exactly-max",
			cert: &ssh.Certificate{
				CertType:    1,
				ValidAfter:  uint64(n.Unix()),
				ValidBefore: uint64(n.Add(24*time.Hour + time.Second).Unix()),
			},
			opts: SignSSHOptions{Backdate: time.Second},
			err:  nil,
		},
		{
			name: "ok",
			cert: &ssh.Certificate{
				CertType:    1,
				ValidAfter:  uint64(now().Unix()),
				ValidBefore: uint64(now().Add(8 * time.Hour).Unix()),
			},
			opts: SignSSHOptions{Backdate: time.Second},
			err:  nil,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			err := v.Valid(tc.cert, tc.opts)
			if err == nil {
				assert.Nil(t, tc.err)
				return
			}
			if assert.NotNil(t, tc.err) {
				assert.HasPrefix(t, err.Error(), tc.err.Error())
			}
		})
	}
}

// Test_sshValidityModifier exercises sshLimitDuration.Modify with the clock
// frozen by mockNow. It covers:
//   - rejection of unset or unknown certificate types;
//   - rejection of requested windows past the provisioning credential's
//     NotAfter;
//   - defaulting of unset validity bounds (user certs default to 16h);
//   - capping an unset ValidBefore at NotAfter when the default window would
//     exceed it.
func Test_sshValidityModifier(t *testing.T) {
	n, fn := mockNow()
	defer fn()

	p, err := generateX5C(nil)
	assert.FatalError(t, err)
	// valid receives the subtest's *testing.T so that assertion failures are
	// attributed to the subtest that produced them rather than the parent.
	type test struct {
		svm   *sshLimitDuration
		cert  *ssh.Certificate
		valid func(*testing.T, *ssh.Certificate)
		err   error
	}
	tests := map[string]func() test{
		"fail/type-not-set": func() test {
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
				cert: &ssh.Certificate{
					ValidAfter:  uint64(n.Unix()),
					ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
				},
				err: errors.New("ssh certificate type has not been set"),
			}
		},
		"fail/type-not-recognized": func() test {
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(6 * time.Hour)},
				cert: &ssh.Certificate{
					CertType:    4,
					ValidAfter:  uint64(n.Unix()),
					ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
				},
				err: errors.New("ssh certificate has an unknown type: 4"),
			}
		},
		"fail/requested-validAfter-after-limit": func() test {
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
				cert: &ssh.Certificate{
					CertType:    1,
					ValidAfter:  uint64(n.Add(2 * time.Hour).Unix()),
					ValidBefore: uint64(n.Add(8 * time.Hour).Unix()),
				},
				err: errors.New("provisioning credential expiration ("),
			}
		},
		"fail/requested-validBefore-after-limit": func() test {
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(1 * time.Hour)},
				cert: &ssh.Certificate{
					CertType:    1,
					ValidAfter:  uint64(n.Unix()),
					ValidBefore: uint64(n.Add(2 * time.Hour).Unix()),
				},
				err: errors.New("provisioning credential expiration ("),
			}
		},
		"ok/no-limit": func() test {
			va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer},
				cert: &ssh.Certificate{
					CertType: 1,
				},
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.ValidAfter, va)
					assert.Equals(t, cert.ValidBefore, vb)
				},
			}
		},
		"ok/defaults": func() test {
			// Unlike ok/no-limit, this case sets a NotAfter limit that lies
			// beyond the default window, so the defaulted bounds must be
			// left unchanged.
			va, vb := uint64(n.Unix()), uint64(n.Add(16*time.Hour).Unix())
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)},
				cert: &ssh.Certificate{
					CertType: 1,
				},
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.ValidAfter, va)
					assert.Equals(t, cert.ValidBefore, vb)
				},
			}
		},
		"ok/valid-requested-validBefore": func() test {
			va, vb := uint64(n.Unix()), uint64(n.Add(2*time.Hour).Unix())
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
				cert: &ssh.Certificate{
					CertType:    1,
					ValidAfter:  va,
					ValidBefore: vb,
				},
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.ValidAfter, va)
					assert.Equals(t, cert.ValidBefore, vb)
				},
			}
		},
		"ok/empty-requested-validBefore-limit-after-default": func() test {
			va := uint64(n.Unix())
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(24 * time.Hour)},
				cert: &ssh.Certificate{
					CertType:   1,
					ValidAfter: va,
				},
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.ValidAfter, va)
					assert.Equals(t, cert.ValidBefore, uint64(n.Add(16*time.Hour).Unix()))
				},
			}
		},
		"ok/empty-requested-validBefore-limit-before-default": func() test {
			va := uint64(n.Unix())
			return test{
				svm: &sshLimitDuration{Claimer: p.claimer, NotAfter: n.Add(3 * time.Hour)},
				cert: &ssh.Certificate{
					CertType:   1,
					ValidAfter: va,
				},
				valid: func(t *testing.T, cert *ssh.Certificate) {
					assert.Equals(t, cert.ValidAfter, va)
					assert.Equals(t, cert.ValidBefore, uint64(n.Add(3*time.Hour).Unix()))
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tt := run()
			if err := tt.svm.Modify(tt.cert, SignSSHOptions{}); err != nil {
				if assert.NotNil(t, tt.err) {
					assert.HasPrefix(t, err.Error(), tt.err.Error())
				}
				return
			}
			if assert.Nil(t, tt.err) {
				tt.valid(t, tt.cert)
			}
		})
	}
}

// Test_sshDefaultDuration_Option exercises sshDefaultDuration.Modify with the
// clock frozen by mockNow. The cases cover:
//   - default user/host durations, with and without claim overrides;
//   - Backdate applied to an unset ValidAfter;
//   - bounds that are already set and must be respected;
//   - rejection of a zero or unknown certificate type, which must leave the
//     certificate unchanged.
func Test_sshDefaultDuration_Option(t *testing.T) {
	tm, fn := mockNow()
	defer fn()

	newClaimer := func(claims *Claims) *Claimer {
		c, err := NewClaimer(claims, globalProvisionerClaims)
		if err != nil {
			t.Fatal(err)
		}
		return c
	}
	// unix returns the mocked "now" shifted by d, as an SSH timestamp.
	unix := func(d time.Duration) uint64 {
		return uint64(tm.Add(d).Unix())
	}

	type fields struct {
		Claimer *Claimer
	}
	type args struct {
		o    SignSSHOptions
		cert *ssh.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *ssh.Certificate
		wantErr bool
	}{
		{"user", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.UserCert}},
			&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(0), ValidBefore: unix(16 * time.Hour)}, false},
		{"host", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.HostCert}},
			&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(0), ValidBefore: unix(30 * 24 * time.Hour)}, false},
		{"user claim", fields{newClaimer(&Claims{DefaultUserSSHDur: &Duration{1 * time.Hour}})}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.UserCert}},
			&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(0), ValidBefore: unix(1 * time.Hour)}, false},
		{"host claim", fields{newClaimer(&Claims{DefaultHostSSHDur: &Duration{1 * time.Hour}})}, args{SignSSHOptions{}, &ssh.Certificate{CertType: ssh.HostCert}},
			&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(0), ValidBefore: unix(1 * time.Hour)}, false},
		{"user backdate", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert}},
			&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(16 * time.Hour)}, false},
		{"host backdate", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert}},
			&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(30 * 24 * time.Hour)}, false},
		{"user validAfter", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(1 * time.Hour)}},
			&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(time.Hour), ValidBefore: unix(17 * time.Hour)}, false},
		{"user validBefore", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.UserCert, ValidBefore: unix(1 * time.Hour)}},
			&ssh.Certificate{CertType: ssh.UserCert, ValidAfter: unix(-1 * time.Minute), ValidBefore: unix(time.Hour)}, false},
		{"host validAfter validBefore", fields{newClaimer(nil)}, args{SignSSHOptions{Backdate: 1 * time.Minute}, &ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}},
			&ssh.Certificate{CertType: ssh.HostCert, ValidAfter: unix(1 * time.Minute), ValidBefore: unix(2 * time.Minute)}, false},
		{"fail zero", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{}}, &ssh.Certificate{}, true},
		{"fail type", fields{newClaimer(nil)}, args{SignSSHOptions{}, &ssh.Certificate{CertType: 3}}, &ssh.Certificate{CertType: 3}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			m := &sshDefaultDuration{
				Claimer: tt.fields.Claimer,
			}
			if err := m.Modify(tt.args.cert, tt.args.o); (err != nil) != tt.wantErr {
				t.Errorf("sshDefaultDuration.Modify() error = %v, wantErr %v", err, tt.wantErr)
			}
			// Compared even on error: failing cases expect the certificate
			// to be left untouched.
			if !reflect.DeepEqual(tt.args.cert, tt.want) {
				t.Errorf("sshDefaultDuration.Modify() = %v, want %v", tt.args.cert, tt.want)
			}
		})
	}
}