package cloudkms

import (
	"context"
	"crypto"
	"crypto/rand"
	"crypto/x509"
	"fmt"
	"io"
	"os"
	"reflect"
	"testing"

	gax "github.com/googleapis/gax-go/v2"
	"go.step.sm/crypto/pemutil"
	kmspb "google.golang.org/genproto/googleapis/cloud/kms/v1"
)

func Test_newSigner(t *testing.T) {
	pemBytes, err := os.ReadFile("testdata/pub.pem")
	if err != nil {
		t.Fatal(err)
	}
	pk, err := pemutil.ParseKey(pemBytes)
	if err != nil {
		t.Fatal(err)
	}

	type args struct {
		c          KeyManagementClient
		signingKey string
	}
	tests := []struct {
		name    string
		args    args
		want    *Signer
		wantErr bool
	}{
		{"ok", args{&MockClient{
			getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
				return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
			},
		}, "signingKey"}, &Signer{client: &MockClient{}, signingKey: "signingKey", publicKey: pk}, false},
		{"fail get public key", args{&MockClient{
			getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
				return nil, fmt.Errorf("an error")
			},
		}, "signingKey"}, nil, true},
		{"fail parse pem", args{&MockClient{
			getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
				return &kmspb.PublicKey{Pem: string("bad pem")}, nil
			},
		}, "signingKey"}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := NewSigner(tt.args.c, tt.args.signingKey)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != nil {
				got.client = &MockClient{}
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("NewSigner() = %v, want %v", got, tt.want)
			}
		})
	}
}

func Test_signer_Public(t *testing.T) {
	pemBytes, err := os.ReadFile("testdata/pub.pem")
	if err != nil {
		t.Fatal(err)
	}
	pk, err := pemutil.ParseKey(pemBytes)
	if err != nil {
		t.Fatal(err)
	}

	type fields struct {
		client     KeyManagementClient
		signingKey string
		publicKey  crypto.PublicKey
	}
	tests := []struct {
		name   string
		fields fields
		want   crypto.PublicKey
	}{
		{"ok", fields{&MockClient{}, "signingKey", pk}, pk},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &Signer{
				client:     tt.fields.client,
				signingKey: tt.fields.signingKey,
				publicKey:  tt.fields.publicKey,
			}
			if got := s.Public(); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("signer.Public() = %v, want %v", got, tt.want)
			}
		})
	}
}

func Test_signer_Sign(t *testing.T) {
	keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
	okClient := &MockClient{
		asymmetricSign: func(_ context.Context, _ *kmspb.AsymmetricSignRequest, _ ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) {
			return &kmspb.AsymmetricSignResponse{Signature: []byte("ok signature")}, nil
		},
	}
	failClient := &MockClient{
		asymmetricSign: func(_ context.Context, _ *kmspb.AsymmetricSignRequest, _ ...gax.CallOption) (*kmspb.AsymmetricSignResponse, error) {
			return nil, fmt.Errorf("an error")
		},
	}

	type fields struct {
		client     KeyManagementClient
		signingKey string
	}
	type args struct {
		rand   io.Reader
		digest []byte
		opts   crypto.SignerOpts
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    []byte
		wantErr bool
	}{
		{"ok sha256", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA256}, []byte("ok signature"), false},
		{"ok sha384", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA384}, []byte("ok signature"), false},
		{"ok sha512", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA512}, []byte("ok signature"), false},
		{"fail MD5", fields{okClient, keyName}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true},
		{"fail asymmetric sign", fields{failClient, keyName}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &Signer{
				client:     tt.fields.client,
				signingKey: tt.fields.signingKey,
			}
			got, err := s.Sign(tt.args.rand, tt.args.digest, tt.args.opts)
			if (err != nil) != tt.wantErr {
				t.Errorf("signer.Sign() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("signer.Sign() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestSigner_SignatureAlgorithm(t *testing.T) {
	pemBytes, err := os.ReadFile("testdata/pub.pem")
	if err != nil {
		t.Fatal(err)
	}

	client := &MockClient{
		getPublicKey: func(_ context.Context, req *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
			var algorithm kmspb.CryptoKeyVersion_CryptoKeyVersionAlgorithm
			switch req.Name {
			case "ECDSA-SHA256":
				algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P256_SHA256
			case "ECDSA-SHA384":
				algorithm = kmspb.CryptoKeyVersion_EC_SIGN_P384_SHA384
			case "SHA256-RSA-2048":
				algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_2048_SHA256
			case "SHA256-RSA-3072":
				algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_3072_SHA256
			case "SHA256-RSA-4096":
				algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA256
			case "SHA512-RSA-4096":
				algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PKCS1_4096_SHA512
			case "SHA256-RSAPSS-2048":
				algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_2048_SHA256
			case "SHA256-RSAPSS-3072":
				algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_3072_SHA256
			case "SHA256-RSAPSS-4096":
				algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA256
			case "SHA512-RSAPSS-4096":
				algorithm = kmspb.CryptoKeyVersion_RSA_SIGN_PSS_4096_SHA512
			}
			return &kmspb.PublicKey{
				Pem:       string(pemBytes),
				Algorithm: algorithm,
			}, nil
		},
	}

	if err != nil {
		t.Fatal(err)
	}

	type fields struct {
		client     KeyManagementClient
		signingKey string
	}
	tests := []struct {
		name   string
		fields fields
		want   x509.SignatureAlgorithm
	}{
		{"ECDSA-SHA256", fields{client, "ECDSA-SHA256"}, x509.ECDSAWithSHA256},
		{"ECDSA-SHA384", fields{client, "ECDSA-SHA384"}, x509.ECDSAWithSHA384},
		{"SHA256-RSA-2048", fields{client, "SHA256-RSA-2048"}, x509.SHA256WithRSA},
		{"SHA256-RSA-3072", fields{client, "SHA256-RSA-3072"}, x509.SHA256WithRSA},
		{"SHA256-RSA-4096", fields{client, "SHA256-RSA-4096"}, x509.SHA256WithRSA},
		{"SHA512-RSA-4096", fields{client, "SHA512-RSA-4096"}, x509.SHA512WithRSA},
		{"SHA256-RSAPSS-2048", fields{client, "SHA256-RSAPSS-2048"}, x509.SHA256WithRSAPSS},
		{"SHA256-RSAPSS-3072", fields{client, "SHA256-RSAPSS-3072"}, x509.SHA256WithRSAPSS},
		{"SHA256-RSAPSS-4096", fields{client, "SHA256-RSAPSS-4096"}, x509.SHA256WithRSAPSS},
		{"SHA512-RSAPSS-4096", fields{client, "SHA512-RSAPSS-4096"}, x509.SHA512WithRSAPSS},
		{"unknown", fields{client, "UNKNOWN"}, x509.UnknownSignatureAlgorithm},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			signer, err := NewSigner(tt.fields.client, tt.fields.signingKey)
			if err != nil {
				t.Errorf("NewSigner() error = %v", err)
			}
			if got := signer.SignatureAlgorithm(); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Signer.SignatureAlgorithm() = %v, want %v", got, tt.want)
			}
		})
	}
}