package sshagentkms

import (
	"bytes"
	"context"
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"encoding/pem"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"reflect"
	"strconv"
	"strings"
	"testing"

	"github.com/smallstep/certificates/kms/apiv1"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/agent"

	"go.step.sm/crypto/pemutil"
)

// Some helpers with inspiration from crypto/ssh/agent/client_test.go

// startOpenSSHAgent executes ssh-agent, and returns an Agent interface to it.
func startOpenSSHAgent(t *testing.T) (client agent.Agent, socket string, cleanup func()) {
	/* Always test with OpenSSHAgent
	if testing.Short() {
		// ssh-agent is not always available, and the key
		// types supported vary by platform.
		t.Skip("skipping test due to -short")
	}
	*/

	bin, err := exec.LookPath("ssh-agent")
	if err != nil {
		t.Skip("could not find ssh-agent")
	}

	cmd := exec.Command(bin, "-s")
	cmd.Env = []string{} // Do not let the user's environment influence ssh-agent behavior.
	cmd.Stderr = new(bytes.Buffer)
	out, err := cmd.Output()
	if err != nil {
		t.Fatalf("%s failed: %v\n%s", strings.Join(cmd.Args, " "), err, cmd.Stderr)
	}

	// Output looks like:
	//
	//	SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK;
	//	SSH_AGENT_PID=15542; export SSH_AGENT_PID;
	//	echo Agent pid 15542;

	fields := bytes.Split(out, []byte(";"))
	line := bytes.SplitN(fields[0], []byte("="), 2)
	line[0] = bytes.TrimLeft(line[0], "\n")
	if string(line[0]) != "SSH_AUTH_SOCK" {
		t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0])
	}
	socket = string(line[1])

	line = bytes.SplitN(fields[2], []byte("="), 2)
	line[0] = bytes.TrimLeft(line[0], "\n")
	if string(line[0]) != "SSH_AGENT_PID" {
		t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2])
	}
	pidStr := line[1]
	pid, err := strconv.Atoi(string(pidStr))
	if err != nil {
		t.Fatalf("Atoi(%q): %v", pidStr, err)
	}

	conn, err := net.Dial("unix", string(socket))
	if err != nil {
		t.Fatalf("net.Dial: %v", err)
	}

	ac := agent.NewClient(conn)
	return ac, socket, func() {
		proc, _ := os.FindProcess(pid)
		if proc != nil {
			proc.Kill()
		}
		conn.Close()
		os.RemoveAll(filepath.Dir(socket))
	}
}

func startAgent(t *testing.T, sshagent agent.Agent) (client agent.Agent, cleanup func()) {
	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}
	go agent.ServeAgent(sshagent, c2)

	return agent.NewClient(c1), func() {
		c1.Close()
		c2.Close()
	}
}

// startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client.
func startKeyringAgent(t *testing.T) (client agent.Agent, cleanup func()) {
	return startAgent(t, agent.NewKeyring())
}

type startTestAgentFunc func(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent)

func startTestOpenSSHAgent(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent) {
	sshagent, _, cleanup := startOpenSSHAgent(t)
	for _, keyToAdd := range keysToAdd {
		err := sshagent.Add(keyToAdd)
		if err != nil {
			t.Fatalf("sshagent.add: %v", err)
		}
	}
	t.Cleanup(cleanup)

	//testAgentInterface(t, sshagent, key, cert, lifetimeSecs)
	return sshagent
}

func startTestKeyringAgent(t *testing.T, keysToAdd ...agent.AddedKey) (sshagent agent.Agent) {
	sshagent, cleanup := startKeyringAgent(t)
	for _, keyToAdd := range keysToAdd {
		err := sshagent.Add(keyToAdd)
		if err != nil {
			t.Fatalf("sshagent.add: %v", err)
		}
	}
	t.Cleanup(cleanup)

	//testAgentInterface(t, agent, key, cert, lifetimeSecs)
	return sshagent
}

// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
// therefore is buffered (net.Pipe deadlocks if both sides start with
// a write.)
func netPipe() (net.Conn, net.Conn, error) {
	listener, err := netListener()
	if err != nil {
		return nil, nil, err
	}
	defer listener.Close()
	c1, err := net.Dial("tcp", listener.Addr().String())
	if err != nil {
		return nil, nil, err
	}

	c2, err := listener.Accept()
	if err != nil {
		c1.Close()
		return nil, nil, err
	}

	return c1, c2, nil
}

// netListener creates a localhost network listener.
func netListener() (net.Listener, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		listener, err = net.Listen("tcp", "[::1]:0")
		if err != nil {
			return nil, err
		}
	}
	return listener, nil
}

func TestNew(t *testing.T) {
	comment := "Key from OpenSSHAgent"
	// Ensure we don't "inherit" any SSH_AUTH_SOCK
	os.Unsetenv("SSH_AUTH_SOCK")

	sshagent, socket, cleanup := startOpenSSHAgent(t)

	os.Setenv("SSH_AUTH_SOCK", socket)
	t.Cleanup(func() {
		os.Unsetenv("SSH_AUTH_SOCK")
		cleanup()
	})

	// Test that we can't find any signers in the agent before we have loaded them
	t.Run("No keys with OpenSSHAgent", func(t *testing.T) {
		kms, err := New(context.Background(), apiv1.Options{})
		if kms == nil || err != nil {
			t.Errorf("New() = %v, %v", kms, err)
		}
		signer, err := kms.CreateSigner(&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment})
		if err == nil || signer != nil {
			t.Errorf("SSHAgentKMS.CreateSigner() error = \"%v\", signer = \"%v\"", err, signer)
		}
	})

	// Load ssh test fixtures
	b, err := os.ReadFile("testdata/ssh")
	if err != nil {
		t.Fatal(err)
	}
	privateKey, err := ssh.ParseRawPrivateKey(b)
	if err != nil {
		t.Fatal(err)
	}

	// And add that key to the agent
	err = sshagent.Add(agent.AddedKey{PrivateKey: privateKey, Comment: comment})
	if err != nil {
		t.Fatalf("sshagent.add: %v", err)
	}

	// And test that we can find it when it's loaded
	t.Run("Keys with OpenSSHAgent", func(t *testing.T) {
		kms, err := New(context.Background(), apiv1.Options{})
		if kms == nil || err != nil {
			t.Errorf("New() = %v, %v", kms, err)
		}
		signer, err := kms.CreateSigner(&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment})
		if err != nil || signer == nil {
			t.Errorf("SSHAgentKMS.CreateSigner() error = \"%v\", signer = \"%v\"", err, signer)
		}
	})
}

func TestNewFromAgent(t *testing.T) {
	type args struct {
		ctx  context.Context
		opts apiv1.Options
	}
	tests := []struct {
		name            string
		args            args
		sshagentstarter startTestAgentFunc
		wantErr         bool
	}{
		{"ok OpenSSHAgent", args{context.Background(), apiv1.Options{}}, startTestOpenSSHAgent, false},
		{"ok KeyringAgent", args{context.Background(), apiv1.Options{}}, startTestKeyringAgent, false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := NewFromAgent(tt.args.ctx, tt.args.opts, tt.sshagentstarter(t))
			if (err != nil) != tt.wantErr {
				t.Errorf("NewFromAgent() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got == nil {
				t.Errorf("NewFromAgent() = %v", got)
			}
		})
	}
}

func TestSSHAgentKMS_Close(t *testing.T) {
	tests := []struct {
		name    string
		wantErr bool
	}{
		{"ok", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			k := &SSHAgentKMS{}
			if err := k.Close(); (err != nil) != tt.wantErr {
				t.Errorf("SSHAgentKMS.Close() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func TestSSHAgentKMS_CreateSigner(t *testing.T) {
	pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	pemBlock, err := pemutil.Serialize(pk)
	if err != nil {
		t.Fatal(err)
	}
	pemBlockPassword, err := pemutil.Serialize(pk, pemutil.WithPassword([]byte("pass")))
	if err != nil {
		t.Fatal(err)
	}

	// Read and decode file using standard packages
	b, err := os.ReadFile("testdata/priv.pem")
	if err != nil {
		t.Fatal(err)
	}
	block, _ := pem.Decode(b)
	block.Bytes, err = x509.DecryptPEMBlock(block, []byte("pass")) //nolint
	if err != nil {
		t.Fatal(err)
	}
	pk2, err := x509.ParseECPrivateKey(block.Bytes)
	if err != nil {
		t.Fatal(err)
	}

	// Create a public PEM
	b, err = x509.MarshalPKIXPublicKey(pk.Public())
	if err != nil {
		t.Fatal(err)
	}
	pub := pem.EncodeToMemory(&pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: b,
	})

	// Load ssh test fixtures
	sshPubKeyStr, err := os.ReadFile("testdata/ssh.pub")
	if err != nil {
		t.Fatal(err)
	}
	_, comment, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyStr)
	if err != nil {
		t.Fatal(err)
	}
	b, err = os.ReadFile("testdata/ssh")
	if err != nil {
		t.Fatal(err)
	}
	privateKey, err := ssh.ParseRawPrivateKey(b)
	if err != nil {
		t.Fatal(err)
	}
	sshPrivateKey, err := ssh.NewSignerFromKey(privateKey)
	if err != nil {
		t.Fatal(err)
	}
	wrappedSSHPrivateKey := NewWrappedSignerFromSSHSigner(sshPrivateKey)

	type args struct {
		req *apiv1.CreateSignerRequest
	}
	tests := []struct {
		name    string
		args    args
		want    crypto.Signer
		wantErr bool
	}{
		{"signer", args{&apiv1.CreateSignerRequest{Signer: pk}}, pk, false},
		{"pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlock)}}, pk, false},
		{"pem password", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pem.EncodeToMemory(pemBlockPassword), Password: []byte("pass")}}, pk, false},
		{"file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("pass")}}, pk2, false},
		{"sshagent", args{&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:" + comment}}, wrappedSSHPrivateKey, false},
		{"sshagent Nonexistant", args{&apiv1.CreateSignerRequest{SigningKey: "sshagentkms:Nonexistant"}}, nil, true},
		{"fail", args{&apiv1.CreateSignerRequest{}}, nil, true},
		{"fail bad pem", args{&apiv1.CreateSignerRequest{SigningKeyPEM: []byte("bad pem")}}, nil, true},
		{"fail bad password", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/priv.pem", Password: []byte("bad-pass")}}, nil, true},
		{"fail not a signer", args{&apiv1.CreateSignerRequest{SigningKeyPEM: pub}}, nil, true},
		{"fail not a signer from file", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/pub.pem"}}, nil, true},
		{"fail missing", args{&apiv1.CreateSignerRequest{SigningKey: "testdata/missing"}}, nil, true},
	}
	starters := []struct {
		name    string
		starter startTestAgentFunc
	}{
		{"startTestOpenSSHAgent", startTestOpenSSHAgent},
		{"startTestKeyringAgent", startTestKeyringAgent},
	}
	for _, starter := range starters {
		k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t, agent.AddedKey{PrivateKey: privateKey, Comment: comment}))
		if err != nil {
			t.Fatal(err)
		}
		for _, tt := range tests {
			t.Run(starter.name+"/"+tt.name, func(t *testing.T) {
				got, err := k.CreateSigner(tt.args.req)
				if (err != nil) != tt.wantErr {
					t.Errorf("SSHAgentKMS.CreateSigner() error = %v, wantErr %v", err, tt.wantErr)
					return
				}
				// nolint:gocritic
				switch s := got.(type) {
				case *WrappedSSHSigner:
					gotPkS := s.Sshsigner.PublicKey().(*agent.Key).String() + "\n"
					wantPkS := string(sshPubKeyStr)
					if !reflect.DeepEqual(gotPkS, wantPkS) {
						t.Errorf("SSHAgentKMS.CreateSigner() = %T, want %T", gotPkS, wantPkS)
						t.Errorf("SSHAgentKMS.CreateSigner() = %v, want %v", gotPkS, wantPkS)
					}
				default:
					if !reflect.DeepEqual(got, tt.want) {
						t.Errorf("SSHAgentKMS.CreateSigner() = %T, want %T", got, tt.want)
						t.Errorf("SSHAgentKMS.CreateSigner() = %v, want %v", got, tt.want)
					}
				}
			})
		}
	}
}

/*
func restoreGenerateKey() func() {
	oldGenerateKey := generateKey
	return func() {
		generateKey = oldGenerateKey
	}
}
*/

/*
func TestSSHAgentKMS_CreateKey(t *testing.T) {
	fn := restoreGenerateKey()
	defer fn()

	p256, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	rsa2048, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatal(err)
	}
	edpub, edpriv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatal(err)
	}

	type args struct {
		req *apiv1.CreateKeyRequest
	}
	type params struct {
		kty  string
		crv  string
		size int
	}
	tests := []struct {
		name        string
		args        args
		generateKey func() (interface{}, interface{}, error)
		want        *apiv1.CreateKeyResponse
		wantParams  params
		wantErr     bool
	}{
		{"p256", args{&apiv1.CreateKeyRequest{Name: "p256", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
			return p256.Public(), p256, nil
		}, &apiv1.CreateKeyResponse{Name: "p256", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
		{"rsa", args{&apiv1.CreateKeyRequest{Name: "rsa3072", SignatureAlgorithm: apiv1.SHA256WithRSA}}, func() (interface{}, interface{}, error) {
			return rsa2048.Public(), rsa2048, nil
		}, &apiv1.CreateKeyResponse{Name: "rsa3072", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 0}, false},
		{"rsa2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 2048}}, func() (interface{}, interface{}, error) {
			return rsa2048.Public(), rsa2048, nil
		}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
		{"rsaPSS2048", args{&apiv1.CreateKeyRequest{Name: "rsa2048", SignatureAlgorithm: apiv1.SHA256WithRSAPSS, Bits: 2048}}, func() (interface{}, interface{}, error) {
			return rsa2048.Public(), rsa2048, nil
		}, &apiv1.CreateKeyResponse{Name: "rsa2048", PublicKey: rsa2048.Public(), PrivateKey: rsa2048, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: rsa2048}}, params{"RSA", "", 2048}, false},
		{"ed25519", args{&apiv1.CreateKeyRequest{Name: "ed25519", SignatureAlgorithm: apiv1.PureEd25519}}, func() (interface{}, interface{}, error) {
			return edpub, edpriv, nil
		}, &apiv1.CreateKeyResponse{Name: "ed25519", PublicKey: edpub, PrivateKey: edpriv, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: edpriv}}, params{"OKP", "Ed25519", 0}, false},
		{"default", args{&apiv1.CreateKeyRequest{Name: "default"}}, func() (interface{}, interface{}, error) {
			return p256.Public(), p256, nil
		}, &apiv1.CreateKeyResponse{Name: "default", PublicKey: p256.Public(), PrivateKey: p256, CreateSignerRequest: apiv1.CreateSignerRequest{Signer: p256}}, params{"EC", "P-256", 0}, false},
		{"fail algorithm", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.SignatureAlgorithm(100)}}, func() (interface{}, interface{}, error) {
			return p256.Public(), p256, nil
		}, nil, params{}, true},
		{"fail generate key", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
			return nil, nil, fmt.Errorf("an error")
		}, nil, params{"EC", "P-256", 0}, true},
		{"fail no signer", args{&apiv1.CreateKeyRequest{Name: "fail", SignatureAlgorithm: apiv1.ECDSAWithSHA256}}, func() (interface{}, interface{}, error) {
			return 1, 2, nil
		}, nil, params{"EC", "P-256", 0}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			k := &SSHAgentKMS{}
			generateKey = func(kty, crv string, size int) (interface{}, interface{}, error) {
				if tt.wantParams.kty != kty {
					t.Errorf("GenerateKey() kty = %s, want %s", kty, tt.wantParams.kty)
				}
				if tt.wantParams.crv != crv {
					t.Errorf("GenerateKey() crv = %s, want %s", crv, tt.wantParams.crv)
				}
				if tt.wantParams.size != size {
					t.Errorf("GenerateKey() size = %d, want %d", size, tt.wantParams.size)
				}
				return tt.generateKey()
			}

			got, err := k.CreateKey(tt.args.req)
			if (err != nil) != tt.wantErr {
				t.Errorf("SSHAgentKMS.CreateKey() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("SSHAgentKMS.CreateKey() = %v, want %v", got, tt.want)
			}
		})
	}
}
*/

func TestSSHAgentKMS_GetPublicKey(t *testing.T) {
	b, err := os.ReadFile("testdata/pub.pem")
	if err != nil {
		t.Fatal(err)
	}
	block, _ := pem.Decode(b)
	pub, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		t.Fatal(err)
	}

	// Load ssh test fixtures
	b, err = os.ReadFile("testdata/ssh.pub")
	if err != nil {
		t.Fatal(err)
	}
	sshPubKey, comment, _, _, err := ssh.ParseAuthorizedKey(b)
	if err != nil {
		t.Fatal(err)
	}
	b, err = os.ReadFile("testdata/ssh")
	if err != nil {
		t.Fatal(err)
	}
	// crypto.PrivateKey
	sshPrivateKey, err := ssh.ParseRawPrivateKey(b)
	if err != nil {
		t.Fatal(err)
	}

	type args struct {
		req *apiv1.GetPublicKeyRequest
	}
	tests := []struct {
		name    string
		args    args
		want    crypto.PublicKey
		wantErr bool
	}{
		{"key", args{&apiv1.GetPublicKeyRequest{Name: "testdata/pub.pem"}}, pub, false},
		{"cert", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.crt"}}, pub, false},
		{"sshagent", args{&apiv1.GetPublicKeyRequest{Name: "sshagentkms:" + comment}}, sshPubKey, false},
		{"sshagent Nonexistant", args{&apiv1.GetPublicKeyRequest{Name: "sshagentkms:Nonexistant"}}, nil, true},
		{"fail not exists", args{&apiv1.GetPublicKeyRequest{Name: "testdata/missing"}}, nil, true},
		{"fail type", args{&apiv1.GetPublicKeyRequest{Name: "testdata/cert.key"}}, nil, true},
	}
	starters := []struct {
		name    string
		starter startTestAgentFunc
	}{
		{"startTestOpenSSHAgent", startTestOpenSSHAgent},
		{"startTestKeyringAgent", startTestKeyringAgent},
	}
	for _, starter := range starters {
		k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t, agent.AddedKey{PrivateKey: sshPrivateKey, Comment: comment}))
		if err != nil {
			t.Fatal(err)
		}
		for _, tt := range tests {
			t.Run(starter.name+"/"+tt.name, func(t *testing.T) {
				got, err := k.GetPublicKey(tt.args.req)
				if (err != nil) != tt.wantErr {
					t.Errorf("SSHAgentKMS.GetPublicKey() error = %v, wantErr %v", err, tt.wantErr)
					return
				}
				// nolint:gocritic
				switch tt.want.(type) {
				case ssh.PublicKey:
					// If we want a ssh.PublicKey, protote got to a
					got, err = ssh.NewPublicKey(got)
					if err != nil {
						t.Fatal(err)
					}
				}
				if !reflect.DeepEqual(got, tt.want) {
					t.Errorf("SSHAgentKMS.GetPublicKey() = %T, want %T", got, tt.want)
					t.Errorf("SSHAgentKMS.GetPublicKey() = %v, want %v", got, tt.want)
				}
			})
		}
	}
}

func TestSSHAgentKMS_CreateKey(t *testing.T) {
	starters := []struct {
		name    string
		starter startTestAgentFunc
	}{
		{"startTestOpenSSHAgent", startTestOpenSSHAgent},
		{"startTestKeyringAgent", startTestKeyringAgent},
	}
	for _, starter := range starters {
		k, err := NewFromAgent(context.Background(), apiv1.Options{}, starter.starter(t))
		if err != nil {
			t.Fatal(err)
		}
		t.Run(starter.name+"/CreateKey", func(t *testing.T) {
			got, err := k.CreateKey(&apiv1.CreateKeyRequest{
				Name:               "sshagentkms:0",
				SignatureAlgorithm: apiv1.ECDSAWithSHA256,
			})
			if got != nil {
				t.Error("SSHAgentKMS.CreateKey() shoudn't return a value")
			}
			if err == nil {
				t.Error("SSHAgentKMS.CreateKey() didn't return a value")
			}
		})
	}
}