package provisioner

import (
	"context"
	"crypto/x509"
	"errors"
	"fmt"
	"net/http"
	"testing"
	"time"

	"go.step.sm/crypto/jose"

	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/api/render"
)

// TestK8sSA_Getters checks the identity getters of a generated K8sSA
// provisioner: GetID must be "k8ssa/" followed by the provisioner name,
// GetName must return p.Name, GetType must return TypeK8sSA, and
// GetEncryptedKey must return zero values for all three results.
func TestK8sSA_Getters(t *testing.T) {
	p, err := generateK8sSA(nil)
	assert.FatalError(t, err)
	id := "k8ssa/" + p.Name
	if got := p.GetID(); got != id {
		t.Errorf("K8sSA.GetID() = %v, want %v", got, id)
	}
	if got := p.GetName(); got != p.Name {
		t.Errorf("K8sSA.GetName() = %v, want %v", got, p.Name)
	}
	if got := p.GetType(); got != TypeK8sSA {
		t.Errorf("K8sSA.GetType() = %v, want %v", got, TypeK8sSA)
	}
	// This test expects no encrypted key: empty kid, empty key and ok == false.
	kid, key, ok := p.GetEncryptedKey()
	if kid != "" || key != "" || ok {
		t.Errorf("K8sSA.GetEncryptedKey() = (%v, %v, %v), want (%v, %v, %v)",
			kid, key, ok, "", "", false)
	}
}

// TestK8sSA_authorizeToken exercises K8sSA.authorizeToken with a malformed
// token, tokens the provisioner cannot verify, a token with an invalid
// issuer, and a valid token. Failing cases must return an error that
// implements render.StatusCodedError with the expected status code and whose
// message starts with the expected prefix; the passing case must return
// non-nil claims.
func TestK8sSA_authorizeToken(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		err   error
		code  int
	}
	tests := map[string]func(*testing.T) test{
		"fail/bad-token": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.authorizeToken; error parsing k8sSA token"),
			}
		},
		"fail/not-implemented": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			tok, err := generateToken("", p.Name, testAudiences.Sign[0], "",
				[]string{"test.smallstep.com"}, time.Now(), jwk)
			assert.FatalError(t, err)
			// With no public keys configured, the expected failure is the
			// "TokenReview API integration not implemented" error.
			p.pubKeys = nil
			return test{
				p:     p,
				token: tok,
				err:   errors.New("k8ssa.authorizeToken; k8sSA TokenReview API integration not implemented"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/error-validating-token": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			// p is not built from jwk, so the token signed by jwk is expected
			// to fail validation.
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			tok, err := generateToken("", p.Name, testAudiences.Sign[0], "",
				[]string{"test.smallstep.com"}, time.Now(), jwk)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				err:   errors.New("k8ssa.authorizeToken; error validating k8sSA token and extracting claims"),
				code:  http.StatusUnauthorized,
			}
		},
		"fail/invalid-issuer": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			claims := getK8sSAPayload()
			claims.Claims.Issuer = "invalid"
			tok, err := generateK8sSAToken(jwk, claims)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.authorizeToken; invalid k8sSA token claims: square/go-jose/jwt: validation failed, invalid issuer claim (iss)"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if claims, err := tc.p.authorizeToken(tc.token, testAudiences.Sign); err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.NotNil(t, claims)
				}
			}
		})
	}
}

// TestK8sSA_AuthorizeRevoke checks that K8sSA.AuthorizeRevoke rejects a
// malformed token with a render.StatusCodedError carrying the expected status
// code and message prefix, and accepts a token signed by the provisioner's
// own key.
func TestK8sSA_AuthorizeRevoke(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		err   error
		code  int
	}
	tests := map[string]func(*testing.T) test{
		"fail/invalid-token": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.AuthorizeRevoke: k8ssa.authorizeToken; error parsing k8sSA token"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if err := tc.p.AuthorizeRevoke(context.Background(), tc.token); err != nil {
				// Inspect the error only when the case expects one. Otherwise
				// an unexpected error would be compared against the zero
				// tc.code and produce misleading failures.
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// TestK8sSA_AuthorizeRenew checks that K8sSA.AuthorizeRenew refuses renewal
// when the provisioner's claims set DisableRenewal, returning a
// render.StatusCodedError with the expected status code and message prefix.
// It also checks that renewal is allowed with the default claims.
func TestK8sSA_AuthorizeRenew(t *testing.T) {
	now := time.Now().Truncate(time.Second)
	type test struct {
		p    *K8sSA
		cert *x509.Certificate
		err  error
		code int
	}
	tests := map[string]func(*testing.T) test{
		"fail/renew-disabled": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			// disable renewal and rebuild the claimer from the new claims
			disable := true
			p.Claims = &Claims{DisableRenewal: &disable}
			p.ctl.Claimer, err = NewClaimer(p.Claims, globalProvisionerClaims)
			assert.FatalError(t, err)
			return test{
				p: p,
				cert: &x509.Certificate{
					NotBefore: now,
					NotAfter:  now.Add(time.Hour),
				},
				code: http.StatusUnauthorized,
				err:  fmt.Errorf("renew is disabled for provisioner '%s'", p.GetName()),
			}
		},
		"ok": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p: p,
				cert: &x509.Certificate{
					NotBefore: now,
					NotAfter:  now.Add(time.Hour),
				},
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if err := tc.p.AuthorizeRenew(context.Background(), tc.cert); err != nil {
				// Inspect the error only when the case expects one. Otherwise
				// an unexpected error would be compared against the zero
				// tc.code and produce misleading failures.
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// TestK8sSA_AuthorizeSign checks that K8sSA.AuthorizeSign rejects a malformed
// token with a render.StatusCodedError carrying the expected status code and
// message prefix. For a valid token it expects exactly eight sign options,
// each of a known type, with per-type field checks against the provisioner
// and its claimer.
func TestK8sSA_AuthorizeSign(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		code  int
		err   error
	}
	tests := map[string]func(*testing.T) test{
		"fail/invalid-token": func(t *testing.T) test {
			p, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.AuthorizeSign: k8ssa.authorizeToken; error parsing k8sSA token"),
			}
		},
		"ok": func(t *testing.T) test {
			jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			p, err := generateK8sSA(jwk.Public().Key)
			assert.FatalError(t, err)
			tok, err := generateK8sSAToken(jwk, nil)
			assert.FatalError(t, err)
			return test{
				p:     p,
				token: tok,
			}
		},
	}
	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			tc := tt(t)
			if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil {
				if assert.NotNil(t, tc.err) {
					var sc render.StatusCodedError
					assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
					assert.Equals(t, sc.StatusCode(), tc.code)
					assert.HasPrefix(t, err.Error(), tc.err.Error())
				}
			} else {
				if assert.Nil(t, tc.err) {
					if assert.NotNil(t, opts) {
						// Check the option count up front, as the SSH sign test does.
						assert.Len(t, 8, opts)
						for _, o := range opts {
							switch v := o.(type) {
							case *K8sSA:
							case certificateOptionsFunc:
							case *provisionerExtensionOption:
								assert.Equals(t, v.Type, TypeK8sSA)
								assert.Equals(t, v.Name, tc.p.GetName())
								assert.Equals(t, v.CredentialID, "")
								assert.Len(t, 0, v.KeyValuePairs)
							case profileDefaultDuration:
								assert.Equals(t, time.Duration(v), tc.p.ctl.Claimer.DefaultTLSCertDuration())
							case defaultPublicKeyValidator:
							case *validityValidator:
								assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
								assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
							case *x509NamePolicyValidator:
								assert.Equals(t, nil, v.policyEngine)
							case *WebhookController:
								assert.Len(t, 0, v.webhooks)
							default:
								assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
							}
						}
					}
				}
			}
		})
	}
}

// TestK8sSA_AuthorizeSSHSign verifies that K8sSA.AuthorizeSSHSign fails with a
// render.StatusCodedError (expected status code, expected message prefix) in
// two cases: when the provisioner's claims turn SSH issuance off, and when the
// token is malformed. For a token signed by the provisioner's own key, it must
// return nine sign options, each of a known type, with per-type field checks.
func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
	type test struct {
		p     *K8sSA
		token string
		code  int
		err   error
	}
	tests := map[string]func(*testing.T) test{
		"fail/sshCA-disabled": func(t *testing.T) test {
			prov, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			// Turn SSH issuance off and rebuild prov.ctl.Claimer from the new claims.
			sshCAEnabled := false
			prov.Claims = &Claims{EnableSSHCA: &sshCAEnabled}
			prov.ctl.Claimer, err = NewClaimer(prov.Claims, globalProvisionerClaims)
			assert.FatalError(t, err)
			return test{
				p:     prov,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   fmt.Errorf("k8ssa.AuthorizeSSHSign; sshCA is disabled for k8sSA provisioner '%s'", prov.GetName()),
			}
		},
		"fail/invalid-token": func(t *testing.T) test {
			prov, err := generateK8sSA(nil)
			assert.FatalError(t, err)
			return test{
				p:     prov,
				token: "foo",
				code:  http.StatusUnauthorized,
				err:   errors.New("k8ssa.AuthorizeSSHSign: k8ssa.authorizeToken; error parsing k8sSA token"),
			}
		},
		"ok": func(t *testing.T) test {
			signer, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
			assert.FatalError(t, err)
			prov, err := generateK8sSA(signer.Public().Key)
			assert.FatalError(t, err)
			token, err := generateK8sSAToken(signer, nil)
			assert.FatalError(t, err)
			return test{
				p:     prov,
				token: token,
			}
		},
	}
	for name, build := range tests {
		t.Run(name, func(t *testing.T) {
			tc := build(t)
			opts, err := tc.p.AuthorizeSSHSign(context.Background(), tc.token)
			if err != nil {
				// An error is only acceptable when the case declares one.
				if !assert.NotNil(t, tc.err) {
					return
				}
				var sc render.StatusCodedError
				assert.Fatal(t, errors.As(err, &sc), "error does not implement StatusCodedError interface")
				assert.Equals(t, sc.StatusCode(), tc.code)
				assert.HasPrefix(t, err.Error(), tc.err.Error())
				return
			}
			// Success path: the case must not expect an error, and options must be present.
			if !assert.Nil(t, tc.err) || !assert.NotNil(t, opts) {
				return
			}
			assert.Len(t, 9, opts)
			for _, opt := range opts {
				switch v := opt.(type) {
				case Interface:
				case sshCertificateOptionsFunc:
				case *sshCertOptionsRequireValidator:
					assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
				case *sshCertValidityValidator:
					assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
				case *sshDefaultPublicKeyValidator:
				case *sshCertDefaultValidator:
				case *sshDefaultDuration:
					assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
				case *sshNamePolicyValidator:
					assert.Equals(t, nil, v.userPolicyEngine)
					assert.Equals(t, nil, v.hostPolicyEngine)
				case *WebhookController:
					assert.Len(t, 0, v.webhooks)
				default:
					assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
				}
			}
		})
	}
}