package db

import (
	"crypto/x509"
	"errors"
	"math/big"
	"reflect"
	"testing"

	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/nosql"
	"github.com/smallstep/nosql/database"
)

func TestIsRevoked(t *testing.T) {
	tests := map[string]struct {
		key       string
		db        *DB
		isRevoked bool
		err       error
	}{
		"false/nil db": {
			key: "sn",
		},
		"false/ErrNotFound": {
			key: "sn",
			db:  &DB{&MockNoSQLDB{Err: database.ErrNotFound, Ret1: nil}, true},
		},
		"error/checking bucket": {
			key: "sn",
			db:  &DB{&MockNoSQLDB{Err: errors.New("force"), Ret1: nil}, true},
			err: errors.New("error checking revocation bucket: force"),
		},
		"true": {
			key:       "sn",
			db:        &DB{&MockNoSQLDB{Ret1: []byte("value")}, true},
			isRevoked: true,
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			isRevoked, err := tc.db.IsRevoked(tc.key)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, tc.err.Error(), err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
				assert.Fatal(t, isRevoked == tc.isRevoked)
			}
		})
	}
}

func TestRevoke(t *testing.T) {
	tests := map[string]struct {
		rci *RevokedCertificateInfo
		db  *DB
		err error
	}{
		"error/force isRevoked": {
			rci: &RevokedCertificateInfo{Serial: "sn"},
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
					return nil, false, errors.New("force")
				},
			}, true},
			err: errors.New("error AuthDB CmpAndSwap: force"),
		},
		"error/was already revoked": {
			rci: &RevokedCertificateInfo{Serial: "sn"},
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
					return []byte("foo"), false, nil
				},
			}, true},
			err: ErrAlreadyExists,
		},
		"ok": {
			rci: &RevokedCertificateInfo{Serial: "sn"},
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, sn, old, newval []byte) ([]byte, bool, error) {
					return []byte("foo"), true, nil
				},
			}, true},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			if err := tc.db.Revoke(tc.rci); err != nil {
				if assert.NotNil(t, tc.err) {
					assert.HasPrefix(t, tc.err.Error(), err.Error())
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

func TestUseToken(t *testing.T) {
	type result struct {
		err error
		ok  bool
	}
	tests := map[string]struct {
		id, tok string
		db      *DB
		want    result
	}{
		"fail/force-CmpAndSwap-error": {
			id:  "id",
			tok: "token",
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
					return nil, false, errors.New("force")
				},
			}, true},
			want: result{
				ok:  false,
				err: errors.New("error storing used token used_ott/id"),
			},
		},
		"fail/CmpAndSwap-already-exists": {
			id:  "id",
			tok: "token",
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
					return []byte("foo"), false, nil
				},
			}, true},
			want: result{
				ok: false,
			},
		},
		"ok/cmpAndSwap-success": {
			id:  "id",
			tok: "token",
			db: &DB{&MockNoSQLDB{
				MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
					return []byte("bar"), true, nil
				},
			}, true},
			want: result{
				ok: true,
			},
		},
	}
	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			switch ok, err := tc.db.UseToken(tc.id, tc.tok); {
			case err != nil:
				if assert.NotNil(t, tc.want.err) {
					assert.HasPrefix(t, err.Error(), tc.want.err.Error())
				}
				assert.False(t, ok)
			case ok:
				assert.True(t, tc.want.ok)
			default:
				assert.False(t, tc.want.ok)
			}
		})
	}
}

func TestDB_StoreCertificateChain(t *testing.T) {
	p := &provisioner.JWK{
		ID:   "some-id",
		Name: "admin",
		Type: "JWK",
	}
	chain := []*x509.Certificate{
		{Raw: []byte("the certificate"), SerialNumber: big.NewInt(1234)},
	}
	type fields struct {
		DB   nosql.DB
		isUp bool
	}
	type args struct {
		p     provisioner.Interface
		chain []*x509.Certificate
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		{"ok", fields{&MockNoSQLDB{
			MUpdate: func(tx *database.Tx) error {
				if len(tx.Operations) != 2 {
					t.Fatal("unexpected number of operations")
				}
				assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
				assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
				assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
				assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
				assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
				assert.Equals(t, []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), tx.Operations[1].Value)
				return nil
			},
		}, true}, args{p, chain}, false},
		{"ok no provisioner", fields{&MockNoSQLDB{
			MUpdate: func(tx *database.Tx) error {
				if len(tx.Operations) != 2 {
					t.Fatal("unexpected number of operations")
				}
				assert.Equals(t, []byte("x509_certs"), tx.Operations[0].Bucket)
				assert.Equals(t, []byte("1234"), tx.Operations[0].Key)
				assert.Equals(t, []byte("the certificate"), tx.Operations[0].Value)
				assert.Equals(t, []byte("x509_certs_data"), tx.Operations[1].Bucket)
				assert.Equals(t, []byte("1234"), tx.Operations[1].Key)
				assert.Equals(t, []byte(`{}`), tx.Operations[1].Value)
				return nil
			},
		}, true}, args{nil, chain}, false},
		{"fail store certificate", fields{&MockNoSQLDB{
			MUpdate: func(tx *database.Tx) error {
				return errors.New("test error")
			},
		}, true}, args{p, chain}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			d := &DB{
				DB:   tt.fields.DB,
				isUp: tt.fields.isUp,
			}
			if err := d.StoreCertificateChain(tt.args.p, tt.args.chain...); (err != nil) != tt.wantErr {
				t.Errorf("DB.StoreCertificateChain() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func TestDB_GetCertificateData(t *testing.T) {
	type fields struct {
		DB   nosql.DB
		isUp bool
	}
	type args struct {
		serialNumber string
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *CertificateData
		wantErr bool
	}{
		{"ok", fields{&MockNoSQLDB{
			MGet: func(bucket, key []byte) ([]byte, error) {
				assert.Equals(t, bucket, []byte("x509_certs_data"))
				assert.Equals(t, key, []byte("1234"))
				return []byte(`{"provisioner":{"id":"some-id","name":"admin","type":"JWK"}}`), nil
			},
		}, true}, args{"1234"}, &CertificateData{
			Provisioner: &ProvisionerData{
				ID: "some-id", Name: "admin", Type: "JWK",
			},
		}, false},
		{"fail not found", fields{&MockNoSQLDB{
			MGet: func(bucket, key []byte) ([]byte, error) {
				return nil, database.ErrNotFound
			},
		}, true}, args{"1234"}, nil, true},
		{"fail db", fields{&MockNoSQLDB{
			MGet: func(bucket, key []byte) ([]byte, error) {
				return nil, errors.New("an error")
			},
		}, true}, args{"1234"}, nil, true},
		{"fail unmarshal", fields{&MockNoSQLDB{
			MGet: func(bucket, key []byte) ([]byte, error) {
				return []byte(`{"bad-json"}`), nil
			},
		}, true}, args{"1234"}, nil, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := &DB{
				DB:   tt.fields.DB,
				isUp: tt.fields.isUp,
			}
			got, err := db.GetCertificateData(tt.args.serialNumber)
			if (err != nil) != tt.wantErr {
				t.Errorf("DB.GetCertificateData() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("DB.GetCertificateData() = %v, want %v", got, tt.want)
			}
		})
	}
}