package provisioner

import (
	"context"
	"crypto/x509"
	"fmt"
	"reflect"
	"testing"
	"time"

	"golang.org/x/crypto/ssh"
)

// trueValue is an addressable true; the tests in this file take its address to
// set optional claim fields such as DisableRenewal and AllowRenewAfterExpiry.
var trueValue = true

// mustClaimer returns the Claimer that NewClaimer creates from claims and
// global. If NewClaimer reports an error the calling test is aborted with
// t.Fatal, so callers can use the result inline in table literals.
func mustClaimer(t *testing.T, claims *Claims, global Claims) *Claimer {
	t.Helper()
	claimer, err := NewClaimer(claims, global)
	if err != nil {
		t.Fatal(err)
	}
	return claimer
}
// mustDuration returns the Duration that NewDuration produces for s. If
// NewDuration reports an error the calling test is aborted with t.Fatal, so
// callers can use the result inline in table literals.
func mustDuration(t *testing.T, s string) *Duration {
	t.Helper()
	duration, err := NewDuration(s)
	if err != nil {
		t.Fatal(err)
	}
	return duration
}

// TestNewController checks the Controller that NewController builds from a
// provisioner, optional provisioner-level claims and the authority Config, as
// well as the error returned when the claims are rejected.
func TestNewController(t *testing.T) {
	type input struct {
		provisioner Interface
		claims      *Claims
		config      Config
	}
	// Every case is run against the same authority-level configuration.
	newConfig := func() Config {
		return Config{
			Claims:    globalProvisionerClaims,
			Audiences: testAudiences,
		}
	}
	cases := []struct {
		name      string
		in        input
		expected  *Controller
		expectErr bool
	}{
		{
			name: "ok",
			in: input{
				provisioner: &JWK{},
				config:      newConfig(),
			},
			expected: &Controller{
				Interface: &JWK{},
				Audiences: &testAudiences,
				Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
			},
		},
		{
			name: "ok with claims",
			in: input{
				provisioner: &JWK{},
				claims:      &Claims{DisableRenewal: &defaultDisableRenewal},
				config:      newConfig(),
			},
			expected: &Controller{
				Interface: &JWK{},
				Audiences: &testAudiences,
				Claimer: mustClaimer(t, &Claims{
					DisableRenewal: &defaultDisableRenewal,
				}, globalProvisionerClaims),
			},
		},
		{
			// The minimum TLS duration (24h) is larger than the maximum (2h).
			name: "fail claimer",
			in: input{
				provisioner: &JWK{},
				claims: &Claims{
					MinTLSDur: mustDuration(t, "24h"),
					MaxTLSDur: mustDuration(t, "2h"),
				},
				config: newConfig(),
			},
			expectErr: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctl, err := NewController(tc.in.provisioner, tc.in.claims, tc.in.config)
			if gotErr := err != nil; gotErr != tc.expectErr {
				t.Errorf("NewController() error = %v, wantErr %v", err, tc.expectErr)
				return
			}
			if !reflect.DeepEqual(ctl, tc.expected) {
				t.Errorf("NewController() = %v, want %v", ctl, tc.expected)
			}
		})
	}
}

// TestController_GetIdentity covers Controller.GetIdentity with and without a
// custom IdentityFunc, using OIDC and JWK provisioners as the wrapped
// Interface. Every case looks up the same email with a background context.
func TestController_GetIdentity(t *testing.T) {
	const email = "jane@doe.org"
	background := context.Background()

	// Custom identity functions; both ignore their arguments.
	customOK := func(context.Context, Interface, string) (*Identity, error) {
		return &Identity{Usernames: []string{"jane"}}, nil
	}
	customFail := func(context.Context, Interface, string) (*Identity, error) {
		return nil, fmt.Errorf("an error")
	}

	cases := []struct {
		name         string
		provisioner  Interface
		identityFunc GetIdentityFunc
		expected     *Identity
		expectErr    bool
	}{
		{
			name:        "ok",
			provisioner: &OIDC{},
			expected:    &Identity{Usernames: []string{"jane", "jane@doe.org"}},
		},
		{
			name:         "ok custom",
			provisioner:  &OIDC{},
			identityFunc: customOK,
			expected:     &Identity{Usernames: []string{"jane"}},
		},
		{
			name:        "fail provisioner",
			provisioner: &JWK{},
			expectErr:   true,
		},
		{
			name:         "fail custom",
			provisioner:  &OIDC{},
			identityFunc: customFail,
			expectErr:    true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ctl := &Controller{
				Interface:    tc.provisioner,
				IdentityFunc: tc.identityFunc,
			}
			identity, err := ctl.GetIdentity(background, email)
			if gotErr := err != nil; gotErr != tc.expectErr {
				t.Errorf("Controller.GetIdentity() error = %v, wantErr %v", err, tc.expectErr)
				return
			}
			if !reflect.DeepEqual(identity, tc.expected) {
				t.Errorf("Controller.GetIdentity() = %v, want %v", identity, tc.expected)
			}
		})
	}
}

// TestController_AuthorizeRenew exercises Controller.AuthorizeRenew for X.509
// certificates. Cases with a nil AuthorizeRenewFunc vary the claims and the
// certificate validity window; the "custom" cases install an
// AuthorizeRenewFunc on the Controller and expect its result to be returned.
func TestController_AuthorizeRenew(t *testing.T) {
	ctx := context.Background()
	now := time.Now().Truncate(time.Second)
	type fields struct {
		Interface          Interface
		Claimer            *Claimer
		AuthorizeRenewFunc AuthorizeRenewFunc
	}
	type args struct {
		ctx  context.Context
		cert *x509.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		// Currently valid certificate, default claims, no custom function.
		{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
			NotBefore: now,
			NotAfter:  now.Add(time.Hour),
		}}, false},
		{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
			return nil
		}}, args{ctx, &x509.Certificate{
			NotBefore: now,
			NotAfter:  now.Add(time.Hour),
		}}, false},
		// Renewal is disabled in the claims, but a custom function that accepts
		// the certificate is installed and its nil result is expected.
		// NOTE(review): relies on AuthorizeRenew consulting AuthorizeRenewFunc
		// before the claims; confirm against the Controller implementation.
		{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
			return nil
		}}, args{ctx, &x509.Certificate{
			NotBefore: now,
			NotAfter:  now.Add(time.Hour),
		}}, false},
		// Expired one minute ago, but the claims allow renewal after expiry.
		{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
			NotBefore: now.Add(-time.Hour),
			NotAfter:  now.Add(-time.Minute),
		}}, false},
		{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
			NotBefore: now,
			NotAfter:  now.Add(time.Hour),
		}}, true},
		// Validity starts one hour from now.
		{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
			NotBefore: now.Add(time.Hour),
			NotAfter:  now.Add(2 * time.Hour),
		}}, true},
		// Expired one minute ago and the default claims are in effect.
		{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
			NotBefore: now.Add(-time.Hour),
			NotAfter:  now.Add(-time.Minute),
		}}, true},
		// The certificate is currently valid; the error comes from the custom
		// function.
		{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
			return fmt.Errorf("an error")
		}}, args{ctx, &x509.Certificate{
			NotBefore: now,
			NotAfter:  now.Add(time.Hour),
		}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Controller{
				Interface:          tt.fields.Interface,
				Claimer:            tt.fields.Claimer,
				AuthorizeRenewFunc: tt.fields.AuthorizeRenewFunc,
			}
			if err := c.AuthorizeRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
				t.Errorf("Controller.AuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestController_AuthorizeSSHRenew exercises Controller.AuthorizeSSHRenew for
// SSH certificates. Cases with a nil AuthorizeSSHRenewFunc vary the claims and
// the ValidAfter/ValidBefore window (Unix seconds); the "custom" cases install
// an AuthorizeSSHRenewFunc on the Controller and expect its result to be
// returned.
func TestController_AuthorizeSSHRenew(t *testing.T) {
	ctx := context.Background()
	now := time.Now()
	type fields struct {
		Interface             Interface
		Claimer               *Claimer
		AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
	}
	type args struct {
		ctx  context.Context
		cert *ssh.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		// Currently valid certificate, default claims, no custom function.
		{"ok", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, false},
		{"ok custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
			return nil
		}}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, false},
		// Renewal is disabled in the claims, but a custom function that accepts
		// the certificate is installed and its nil result is expected.
		// NOTE(review): relies on AuthorizeSSHRenew consulting
		// AuthorizeSSHRenewFunc before the claims; confirm against the
		// Controller implementation.
		{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
			return nil
		}}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, false},
		// Expired one minute ago, but the claims allow renewal after expiry.
		{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(-time.Hour).Unix()),
			ValidBefore: uint64(now.Add(-time.Minute).Unix()),
		}}, false},
		{"fail disabled", fields{&JWK{}, mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, true},
		// Validity starts one hour from now.
		{"fail not yet valid", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(time.Hour).Unix()),
			ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
		}}, true},
		// Expired one minute ago and the default claims are in effect.
		{"fail expired", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(-time.Hour).Unix()),
			ValidBefore: uint64(now.Add(-time.Minute).Unix()),
		}}, true},
		// The certificate is currently valid; the error comes from the custom
		// function.
		{"fail custom", fields{&JWK{}, mustClaimer(t, nil, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
			return fmt.Errorf("an error")
		}}, args{ctx, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			c := &Controller{
				Interface:             tt.fields.Interface,
				Claimer:               tt.fields.Claimer,
				AuthorizeSSHRenewFunc: tt.fields.AuthorizeSSHRenewFunc,
			}
			if err := c.AuthorizeSSHRenew(tt.args.ctx, tt.args.cert); (err != nil) != tt.wantErr {
				t.Errorf("Controller.AuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestDefaultAuthorizeRenew calls DefaultAuthorizeRenew directly with
// Controllers whose claims and X.509 certificate validity windows vary, and
// checks only whether an error is returned.
func TestDefaultAuthorizeRenew(t *testing.T) {
	ctx := context.Background()
	now := time.Now().Truncate(time.Second)
	type args struct {
		ctx  context.Context
		p    *Controller
		cert *x509.Certificate
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		// Currently valid certificate with the default claims.
		{"ok", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &x509.Certificate{
			NotBefore: now,
			NotAfter:  now.Add(time.Hour),
		}}, false},
		// Expired one minute ago, but the claims allow renewal after expiry.
		{"ok renew after expiry", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
		}, &x509.Certificate{
			NotBefore: now.Add(-time.Hour),
			NotAfter:  now.Add(-time.Minute),
		}}, false},
		{"fail disabled", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
		}, &x509.Certificate{
			NotBefore: now,
			NotAfter:  now.Add(time.Hour),
		}}, true},
		// The two cases below use the default claims (the same ones the "ok"
		// case succeeds with), so the expected error has to come from the
		// validity window and not from renewal being disabled.
		{"fail not yet valid", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &x509.Certificate{
			NotBefore: now.Add(time.Hour),
			NotAfter:  now.Add(2 * time.Hour),
		}}, true},
		{"fail expired", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &x509.Certificate{
			NotBefore: now.Add(-time.Hour),
			NotAfter:  now.Add(-time.Minute),
		}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := DefaultAuthorizeRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
				t.Errorf("DefaultAuthorizeRenew() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

// TestDefaultAuthorizeSSHRenew calls DefaultAuthorizeSSHRenew directly with
// Controllers whose claims and SSH certificate validity windows (Unix seconds)
// vary, and checks only whether an error is returned.
func TestDefaultAuthorizeSSHRenew(t *testing.T) {
	ctx := context.Background()
	now := time.Now()
	type args struct {
		ctx  context.Context
		p    *Controller
		cert *ssh.Certificate
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		// Currently valid certificate with the default claims.
		{"ok", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, false},
		// Expired one minute ago, but the claims allow renewal after expiry.
		{"ok renew after expiry", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(-time.Hour).Unix()),
			ValidBefore: uint64(now.Add(-time.Minute).Unix()),
		}}, false},
		{"fail disabled", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, &Claims{DisableRenewal: &trueValue}, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Unix()),
			ValidBefore: uint64(now.Add(time.Hour).Unix()),
		}}, true},
		// The two cases below use the default claims (the same ones the "ok"
		// case succeeds with), so the expected error has to come from the
		// validity window and not from renewal being disabled.
		{"fail not yet valid", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(time.Hour).Unix()),
			ValidBefore: uint64(now.Add(2 * time.Hour).Unix()),
		}}, true},
		{"fail expired", args{ctx, &Controller{
			Interface: &JWK{},
			Claimer:   mustClaimer(t, nil, globalProvisionerClaims),
		}, &ssh.Certificate{
			ValidAfter:  uint64(now.Add(-time.Hour).Unix()),
			ValidBefore: uint64(now.Add(-time.Minute).Unix()),
		}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := DefaultAuthorizeSSHRenew(tt.args.ctx, tt.args.p, tt.args.cert); (err != nil) != tt.wantErr {
				t.Errorf("DefaultAuthorizeSSHRenew() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}