package tree

import (
	"context"
	"crypto/ecdsa"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"testing"

	"git.frostfs.info/TrueCloudLab/frostfs-contract/frostfsid/client"
	containercore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/container"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/netmap"
	checkercore "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/common/ape"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/util/logger/test"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/ape"
	aclV2 "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/api/acl"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/bearer"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/acl"
	cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
	netmapSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/chain"
	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine"
	"git.frostfs.info/TrueCloudLab/policy-engine/pkg/engine/inmemory"
	"git.frostfs.info/TrueCloudLab/policy-engine/schema/native"
	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
	"github.com/nspcc-dev/neo-go/pkg/util"
	"github.com/stretchr/testify/require"
)

// dummyNetmapSource satisfies netmap.Source purely by embedding the
// interface. The embedded value is nil, so invoking any of its methods
// panics; it is only usable where the network map is never consulted.
type dummyNetmapSource struct{ netmap.Source }

// dummySubjectProvider is an in-memory subject provider backed by a map keyed
// by util.Uint160. Lookups never fail: an address missing from the map yields
// a zero-valued subject together with a nil error.
type dummySubjectProvider struct {
	subjects map[util.Uint160]client.SubjectExtended
}

// GetSubject returns the basic (non-extended) fields of the subject stored
// under addr: primary and additional keys, namespace, name and KV.
func (p dummySubjectProvider) GetSubject(addr util.Uint160) (*client.Subject, error) {
	ext := p.subjects[addr]

	subj := new(client.Subject)
	subj.PrimaryKey = ext.PrimaryKey
	subj.AdditionalKeys = ext.AdditionalKeys
	subj.Namespace = ext.Namespace
	subj.Name = ext.Name
	subj.KV = ext.KV
	return subj, nil
}

// GetSubjectExtended returns a pointer to a shallow copy of the subject
// stored under addr.
func (p dummySubjectProvider) GetSubjectExtended(addr util.Uint160) (*client.SubjectExtended, error) {
	subj := new(client.SubjectExtended)
	*subj = p.subjects[addr]
	return subj, nil
}

// dummyEpochSource is an epoch source that always reports the same, fixed
// epoch number; its zero value reports epoch 0.
type dummyEpochSource struct {
	epoch uint64
}

// CurrentEpoch returns the configured epoch number.
func (src dummyEpochSource) CurrentEpoch() uint64 { return src.epoch }

// dummyContainerSource is an in-memory container source keyed by the string
// form of the container ID (as produced by cid.ID.String).
type dummyContainerSource map[string]*containercore.Container

// List decodes every key back into a container ID. The result follows map
// iteration order and is therefore unordered. The first key that fails to
// decode aborts the listing with that decode error.
func (s dummyContainerSource) List() ([]cid.ID, error) {
	ids := make([]cid.ID, 0, len(s))
	for key := range s {
		var id cid.ID
		if err := id.DecodeString(key); err != nil {
			return nil, err
		}
		ids = append(ids, id)
	}
	return ids, nil
}

// Get returns the container registered under id, or an error if there is
// no such entry.
func (s dummyContainerSource) Get(id cid.ID) (*containercore.Container, error) {
	if cnr, found := s[id.String()]; found {
		return cnr, nil
	}
	return nil, errors.New("container not found")
}

// DeletionInfo ignores its argument and always returns an empty, non-nil
// DelInfo with no error.
func (s dummyContainerSource) DeletionInfo(_ cid.ID) (*containercore.DelInfo, error) {
	return &containercore.DelInfo{}, nil
}

// testContainer builds a container owned by owner. Its placement policy holds
// a single replica descriptor with the number of objects set to 1.
func testContainer(owner user.ID) container.Container {
	var replica netmapSDK.ReplicaDescriptor
	replica.SetNumberOfObjects(1)

	var policy netmapSDK.PlacementPolicy
	policy.AddReplicas(replica)

	var result container.Container
	result.SetOwner(owner)
	result.SetPlacementPolicy(policy)
	return result
}

const (
	// currentEpoch is the epoch the tests treat as "now"; bearer tokens built
	// by testBearerToken expire one epoch after it.
	currentEpoch = 123
)

// TestMessageSign exercises Service.verifyClient on a signed MoveRequest. It
// covers missing signatures, signatures by the wrong key, requests for a
// container that is not registered, and bearer tokens carrying APE overrides.
// The bearer-token cases are: malformed, bound to another container, issued by
// a non-owner, carrying a corrupted signature, impersonating, and granting
// per-key put/get rights.
//
// The subtests share and mutate req and cnr, so they depend on running
// sequentially in the order written; do not mark them parallel.
func TestMessageSign(t *testing.T) {
	ctx := context.Background()

	privs := make([]*keys.PrivateKey, 4)
	for i := range privs {
		p, err := keys.NewPrivateKey()
		require.NoError(t, err)
		privs[i] = p
	}

	cid1 := cidtest.ID()
	cid2 := cidtest.ID() // deliberately not registered in cnrSource

	// privs[0] owns the container.
	var ownerID user.ID
	user.IDFromKey(&ownerID, (ecdsa.PublicKey)(*privs[0].PublicKey()))

	cnr := &containercore.Container{
		Value: testContainer(ownerID),
	}

	// Container-level chain: privs[0] may put and get, privs[1] may only get.
	e := inmemory.NewInMemoryLocalOverrides()
	e.MorphRuleChainStorage().AddMorphRuleChain(chain.Ingress, engine.Target{
		Type: engine.Container,
		Name: cid1.EncodeToString(),
	}, testChain(privs[0].PublicKey(), privs[1].PublicKey()))
	frostfsidProvider := dummySubjectProvider{
		subjects: make(map[util.Uint160]client.SubjectExtended),
	}

	s := &Service{
		cfg: cfg{
			log:      test.NewLogger(t),
			key:      &privs[0].PrivateKey,
			nmSource: dummyNetmapSource{},
			cnrSource: dummyContainerSource{
				cid1.String(): cnr,
			},
			frostfsidSubjectProvider: frostfsidProvider,
			state:                    dummyEpochSource{epoch: currentEpoch},
		},
		// NOTE(review): the checker gets a zero-valued epoch source, while the
		// service state reports currentEpoch. If checkercore uses this source to
		// validate bearer-token lifetime, tokens are checked at epoch 0. In that
		// case the expiration set by testBearerToken is not really exercised.
		// Confirm whether the mismatch is intentional.
		apeChecker: checkercore.New(e.LocalStorage(), e.MorphRuleChainStorage(), frostfsidProvider, dummyEpochSource{}),
	}

	rawCID1 := make([]byte, sha256.Size)
	cid1.Encode(rawCID1)

	req := &MoveRequest{
		Body: &MoveRequest_Body{
			ContainerId: rawCID1,
			ParentId:    1,
			NodeId:      2,
			Meta: []KeyValue{
				{Key: "kkk", Value: []byte("vvv")},
			},
		},
	}

	op := acl.OpObjectPut
	cnr.Value.SetBasicACL(acl.PublicRW)

	t.Run("missing signature, no panic", func(t *testing.T) {
		require.Error(t, s.verifyClient(ctx, req, cid2, nil, op))
	})

	require.NoError(t, SignMessage(req, &privs[0].PrivateKey))
	require.NoError(t, s.verifyClient(ctx, req, cid1, nil, op))

	t.Run("invalid CID", func(t *testing.T) {
		require.Error(t, s.verifyClient(ctx, req, cid2, nil, op))
	})

	cnr.Value.SetBasicACL(acl.Private)

	t.Run("extension disabled", func(t *testing.T) {
		// NOTE(review): this targets cid2, which is not in cnrSource. The error
		// may therefore come from the container lookup rather than from the
		// Private basic ACL. Confirm the intended container ID.
		require.NoError(t, SignMessage(req, &privs[0].PrivateKey))
		require.Error(t, s.verifyClient(ctx, req, cid2, nil, op))
	})

	t.Run("invalid key", func(t *testing.T) {
		require.NoError(t, SignMessage(req, &privs[1].PrivateKey))
		require.Error(t, s.verifyClient(ctx, req, cid1, nil, op))
	})

	t.Run("bearer", func(t *testing.T) {
		// SetBasicACL takes the value, so bACL must be fully configured before
		// the call; later changes to bACL would not reach the container.
		bACL := acl.PrivateExtended
		bACL.AllowBearerRules(op)
		cnr.Value.SetBasicACL(bACL)

		t.Run("invalid bearer", func(t *testing.T) {
			req.Body.BearerToken = []byte{0xFF}
			require.NoError(t, SignMessage(req, &privs[0].PrivateKey))
			require.Error(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectPut))
		})

		t.Run("invalid bearer CID", func(t *testing.T) {
			bt := testBearerToken(cid2, privs[1].PublicKey(), privs[2].PublicKey())
			require.NoError(t, bt.Sign(privs[0].PrivateKey))
			req.Body.BearerToken = bt.Marshal()

			require.NoError(t, SignMessage(req, &privs[1].PrivateKey))
			require.Error(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectPut))
		})
		t.Run("invalid bearer owner", func(t *testing.T) {
			bt := testBearerToken(cid1, privs[1].PublicKey(), privs[2].PublicKey())
			require.NoError(t, bt.Sign(privs[1].PrivateKey))
			req.Body.BearerToken = bt.Marshal()

			require.NoError(t, SignMessage(req, &privs[1].PrivateKey))
			require.Error(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectPut))
		})
		t.Run("invalid bearer signature", func(t *testing.T) {
			bt := testBearerToken(cid1, privs[1].PublicKey(), privs[2].PublicKey())
			require.NoError(t, bt.Sign(privs[0].PrivateKey))

			// Re-encode through the v2 structure to overwrite the signature bytes.
			var bv2 aclV2.BearerToken
			bt.WriteToV2(&bv2)
			bv2.GetSignature().SetSign([]byte{1, 2, 3})
			req.Body.BearerToken = bv2.StableMarshal(nil)

			require.NoError(t, SignMessage(req, &privs[1].PrivateKey))
			require.Error(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectPut))
		})

		t.Run("impersonate", func(t *testing.T) {
			cnr.Value.SetBasicACL(acl.PublicRWExtended)
			var bt bearer.Token
			bt.SetImpersonate(true)

			require.NoError(t, bt.Sign(privs[1].PrivateKey))
			req.Body.BearerToken = bt.Marshal()

			require.NoError(t, SignMessage(req, &privs[0].PrivateKey))
			require.Error(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectPut))
			require.NoError(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectGet))
		})

		// Bearer override signed by the owner: privs[1] may put and get,
		// privs[2] may only get, privs[3] has no matching rule.
		bt := testBearerToken(cid1, privs[1].PublicKey(), privs[2].PublicKey())
		require.NoError(t, bt.Sign(privs[0].PrivateKey))
		req.Body.BearerToken = bt.Marshal()
		cnr.Value.SetBasicACL(acl.PublicRWExtended)

		t.Run("put and get", func(t *testing.T) {
			require.NoError(t, SignMessage(req, &privs[1].PrivateKey))
			require.NoError(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectPut))
			require.NoError(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectGet))
		})
		t.Run("only get", func(t *testing.T) {
			require.NoError(t, SignMessage(req, &privs[2].PrivateKey))
			require.Error(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectPut))
			require.NoError(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectGet))
		})
		t.Run("none", func(t *testing.T) {
			require.NoError(t, SignMessage(req, &privs[3].PrivateKey))
			require.Error(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectPut))
			require.Error(t, s.verifyClient(ctx, req, cid1, req.GetBody().GetBearerToken(), acl.OpObjectGet))
		})
	})
}

// testBearerToken returns an unsigned bearer token with expiration epoch
// currentEpoch+1. The token carries an APE override targeting container cnrID,
// whose single chain is testChain(forPutGet, forGet).
func testBearerToken(cnrID cid.ID, forPutGet, forGet *keys.PublicKey) bearer.Token {
	override := bearer.APEOverride{
		Target: ape.ChainTarget{
			TargetType: ape.TargetTypeContainer,
			Name:       cnrID.EncodeToString(),
		},
		Chains: []ape.Chain{
			{Raw: testChain(forPutGet, forGet).Bytes()},
		},
	}

	var tok bearer.Token
	tok.SetExp(currentEpoch + 1)
	tok.SetAPEOverride(override)
	return tok
}

// testChain builds an allow-only APE chain over all objects with two rules:
//   - GetObject is allowed when the request's actor public key equals either
//     forPutGet or forGet (Any: true, so one matching condition suffices);
//   - PutObject is allowed when the actor public key equals forPutGet.
//
// Keys are compared by their hex-encoded byte form.
func testChain(forPutGet, forGet *keys.PublicKey) *chain.Chain {
	actorIs := func(pub *keys.PublicKey) chain.Condition {
		return chain.Condition{
			Op:    chain.CondStringEquals,
			Kind:  chain.KindRequest,
			Key:   native.PropertyKeyActorPublicKey,
			Value: hex.EncodeToString(pub.Bytes()),
		}
	}

	return &chain.Chain{
		Rules: []chain.Rule{
			{
				Status:    chain.Allow,
				Resources: chain.Resources{Names: []string{native.ResourceFormatAllObjects}},
				Actions:   chain.Actions{Names: []string{native.MethodGetObject}},
				Any:       true,
				Condition: []chain.Condition{actorIs(forPutGet), actorIs(forGet)},
			},
			{
				Status:    chain.Allow,
				Resources: chain.Resources{Names: []string{native.ResourceFormatAllObjects}},
				Actions:   chain.Actions{Names: []string{native.MethodPutObject}},
				Any:       true,
				Condition: []chain.Condition{actorIs(forPutGet)},
			},
		},
	}
}