package identity

import (
	"crypto"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"path/filepath"
	"reflect"
	"testing"

	"github.com/smallstep/certificates/api"
	"go.step.sm/crypto/pemutil"
)

func TestLoadDefaultIdentity(t *testing.T) {
	oldFile := IdentityFile
	defer func() {
		IdentityFile = oldFile
	}()

	expected := &Identity{
		Type:        "mTLS",
		Certificate: "testdata/identity/identity.crt",
		Key:         "testdata/identity/identity_key",
	}
	tests := []struct {
		name    string
		prepare func()
		want    *Identity
		wantErr bool
	}{
		{"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false},
		{"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true},
		{"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.prepare()
			got, err := LoadDefaultIdentity()
			if (err != nil) != tt.wantErr {
				t.Errorf("LoadDefaultIdentity() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("LoadDefaultIdentity() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestIdentity_Kind(t *testing.T) {
	type fields struct {
		Type string
	}
	tests := []struct {
		name   string
		fields fields
		want   Type
	}{
		{"disabled", fields{""}, Disabled},
		{"mutualTLS", fields{"mTLS"}, MutualTLS},
		{"tunnelTLS", fields{"tTLS"}, TunnelTLS},
		{"unknown", fields{"unknown"}, Type("unknown")},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			i := &Identity{
				Type: tt.fields.Type,
			}
			if got := i.Kind(); got != tt.want {
				t.Errorf("Identity.Kind() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestIdentity_Validate(t *testing.T) {
	type fields struct {
		Type        string
		Certificate string
		Key         string
		Host        string
		Root        string
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, false},
		{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, false},
		{"ok disabled", fields{}, false},
		{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, true},
		{"fail certificate", fields{"mTLS", "", "testdata/identity/identity_key", "", ""}, true},
		{"fail key", fields{"mTLS", "testdata/identity/identity.crt", "", "", ""}, true},
		{"fail key", fields{"tTLS", "testdata/identity/identity.crt", "", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
		{"fail missing certificate", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, true},
		{"fail missing certificate", fields{"tTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
		{"fail missing key", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", ""}, true},
		{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, true},
		{"fail host", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "", "testdata/certs/root_ca.crt"}, true},
		{"fail root", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			i := &Identity{
				Type:        tt.fields.Type,
				Certificate: tt.fields.Certificate,
				Key:         tt.fields.Key,
				Host:        tt.fields.Host,
				Root:        tt.fields.Root,
			}
			if err := i.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("Identity.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func TestIdentity_TLSCertificate(t *testing.T) {
	expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key")
	if err != nil {
		t.Fatal(err)
	}

	type fields struct {
		Type        string
		Certificate string
		Key         string
	}
	tests := []struct {
		name    string
		fields  fields
		want    tls.Certificate
		wantErr bool
	}{
		{"ok mTLS", fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
		{"ok tTLS", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, expected, false},
		{"ok disabled", fields{}, tls.Certificate{}, false},
		{"fail type", fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
		{"fail certificate", fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
		{"fail not after", fields{"mTLS", "testdata/identity/expired.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
		{"fail not before", fields{"mTLS", "testdata/identity/not_before.crt", "testdata/identity/identity_key"}, tls.Certificate{}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			i := &Identity{
				Type:        tt.fields.Type,
				Certificate: tt.fields.Certificate,
				Key:         tt.fields.Key,
			}
			got, err := i.TLSCertificate()
			if (err != nil) != tt.wantErr {
				t.Errorf("Identity.TLSCertificate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Identity.TLSCertificate() = %v, want %v", got, tt.want)
			}
		})
	}
}

func Test_fileExists(t *testing.T) {
	type args struct {
		filename string
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{"ok", args{"testdata/identity/identity.crt"}, false},
		{"missing", args{"testdata/identity/missing.crt"}, true},
		{"directory", args{"testdata/identity"}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := fileExists(tt.args.filename); (err != nil) != tt.wantErr {
				t.Errorf("fileExists() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func TestWriteDefaultIdentity(t *testing.T) {
	tmpDir, err := ioutil.TempDir(os.TempDir(), "go-tests")
	if err != nil {
		t.Fatal(err)
	}

	oldConfigDir := configDir
	oldIdentityDir := identityDir
	oldIdentityFile := IdentityFile
	defer func() {
		configDir = oldConfigDir
		identityDir = oldIdentityDir
		IdentityFile = oldIdentityFile
		os.RemoveAll(tmpDir)
	}()

	certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt")
	if err != nil {
		t.Fatal(err)
	}
	key, err := pemutil.Read("testdata/identity/identity_key")
	if err != nil {
		t.Fatal(err)
	}

	var certChain []api.Certificate
	for _, c := range certs {
		certChain = append(certChain, api.Certificate{Certificate: c})
	}

	configDir = filepath.Join(tmpDir, "config")
	identityDir = filepath.Join(tmpDir, "identity")
	IdentityFile = filepath.Join(tmpDir, "config", "identity.json")

	type args struct {
		certChain []api.Certificate
		key       crypto.PrivateKey
	}
	tests := []struct {
		name    string
		prepare func()
		args    args
		wantErr bool
	}{
		{"ok", func() {}, args{certChain, key}, false},
		{"fail mkdir config", func() {
			configDir = filepath.Join(tmpDir, "identity", "identity.crt")
			identityDir = filepath.Join(tmpDir, "identity")
		}, args{certChain, key}, true},
		{"fail mkdir identity", func() {
			configDir = filepath.Join(tmpDir, "config")
			identityDir = filepath.Join(tmpDir, "identity", "identity.crt")
		}, args{certChain, key}, true},
		{"fail certificate", func() {
			configDir = filepath.Join(tmpDir, "config")
			identityDir = filepath.Join(tmpDir, "bad-dir")
			os.MkdirAll(identityDir, 0600)
		}, args{certChain, key}, true},
		{"fail key", func() {
			configDir = filepath.Join(tmpDir, "config")
			identityDir = filepath.Join(tmpDir, "identity")
		}, args{certChain, "badKey"}, true},
		{"fail write identity", func() {
			configDir = filepath.Join(tmpDir, "bad-dir")
			identityDir = filepath.Join(tmpDir, "identity")
			IdentityFile = filepath.Join(configDir, "identity.json")
			os.MkdirAll(configDir, 0600)
		}, args{certChain, key}, true},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.prepare()
			if err := WriteDefaultIdentity(tt.args.certChain, tt.args.key); (err != nil) != tt.wantErr {
				t.Errorf("WriteDefaultIdentity() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func TestIdentity_GetClientCertificateFunc(t *testing.T) {
	expected, err := tls.LoadX509KeyPair("testdata/identity/identity.crt", "testdata/identity/identity_key")
	if err != nil {
		t.Fatal(err)
	}

	type fields struct {
		Type        string
		Certificate string
		Key         string
		Host        string
		Root        string
	}
	tests := []struct {
		name    string
		fields  fields
		want    *tls.Certificate
		wantErr bool
	}{
		{"ok mTLS", fields{"mtls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "", ""}, &expected, false},
		{"ok tTLS", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, &expected, false},
		{"fail missing cert", fields{"mTLS", "testdata/identity/missing.crt", "testdata/identity/identity_key", "", ""}, nil, true},
		{"fail missing key", fields{"tTLS", "testdata/identity/identity.crt", "testdata/identity/missing_key", "tunnel:443", "testdata/certs/root_ca.crt"}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			i := &Identity{
				Type:        tt.fields.Type,
				Certificate: tt.fields.Certificate,
				Key:         tt.fields.Key,
				Host:        tt.fields.Host,
				Root:        tt.fields.Root,
			}
			fn := i.GetClientCertificateFunc()
			got, err := fn(&tls.CertificateRequestInfo{})
			if (err != nil) != tt.wantErr {
				t.Errorf("Identity.GetClientCertificateFunc() = %v, wantErr %v", err, tt.wantErr)
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Identity.GetClientCertificateFunc() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestIdentity_GetCertPool(t *testing.T) {
	type fields struct {
		Type        string
		Certificate string
		Key         string
		Host        string
		Root        string
	}
	tests := []struct {
		name         string
		fields       fields
		wantSubjects [][]byte
		wantErr      bool
	}{
		{"ok", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/root_ca.crt"}, [][]byte{[]byte("0\x1c1\x1a0\x18\x06\x03U\x04\x03\x13\x11Smallstep Root CA")}, false},
		{"ok nil", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", ""}, nil, false},
		{"fail missing", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/certs/missing.crt"}, nil, true},
		{"fail no cert", fields{"ttls", "testdata/identity/identity.crt", "testdata/identity/identity_key", "tunnel:443", "testdata/secrets/root_ca_key"}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			i := &Identity{
				Type:        tt.fields.Type,
				Certificate: tt.fields.Certificate,
				Key:         tt.fields.Key,
				Host:        tt.fields.Host,
				Root:        tt.fields.Root,
			}
			got, err := i.GetCertPool()
			if (err != nil) != tt.wantErr {
				t.Errorf("Identity.GetCertPool() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != nil {
				subjects := got.Subjects()
				if !reflect.DeepEqual(subjects, tt.wantSubjects) {
					t.Errorf("Identity.GetCertPool() = %x, want %x", subjects, tt.wantSubjects)
				}
			}

		})
	}
}

type renewer struct {
	pool *x509.CertPool
	sign *api.SignResponse
	err  error
}

func (r *renewer) GetRootCAs() *x509.CertPool {
	return r.pool
}

func (r *renewer) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
	return r.sign, r.err
}

func TestIdentity_Renew(t *testing.T) {
	tmpDir, err := ioutil.TempDir(os.TempDir(), "go-tests")
	if err != nil {
		t.Fatal(err)
	}

	oldIdentityDir := identityDir
	identityDir = "testdata/identity"
	defer func() {
		identityDir = oldIdentityDir
		os.RemoveAll(tmpDir)
	}()

	certs, err := pemutil.ReadCertificateBundle("testdata/identity/identity.crt")
	if err != nil {
		t.Fatal(err)
	}

	ok := &renewer{
		sign: &api.SignResponse{
			ServerPEM: api.Certificate{Certificate: certs[0]},
			CaPEM:     api.Certificate{Certificate: certs[1]},
			CertChainPEM: []api.Certificate{
				{Certificate: certs[0]},
				{Certificate: certs[1]},
			},
		},
	}

	okOld := &renewer{
		sign: &api.SignResponse{
			ServerPEM: api.Certificate{Certificate: certs[0]},
			CaPEM:     api.Certificate{Certificate: certs[1]},
		},
	}

	fail := &renewer{
		err: fmt.Errorf("an error"),
	}

	type fields struct {
		Type        string
		Certificate string
		Key         string
	}
	type args struct {
		client Renewer
	}
	tests := []struct {
		name    string
		prepare func()
		fields  fields
		args    args
		wantErr bool
	}{
		{"ok", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, false},
		{"ok old", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{okOld}, false},
		{"ok disabled", func() {}, fields{}, args{nil}, false},
		{"fail type", func() {}, fields{"foo", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
		{"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true},
		{"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true},
		{"fail write identity", func() {
			identityDir = filepath.Join(tmpDir, "bad-dir")
			os.MkdirAll(identityDir, 0600)
		}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.prepare()
			i := &Identity{
				Type:        tt.fields.Type,
				Certificate: tt.fields.Certificate,
				Key:         tt.fields.Key,
			}
			if err := i.Renew(tt.args.client); (err != nil) != tt.wantErr {
				t.Errorf("Identity.Renew() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}