package getsvc

import (
	"context"
	"crypto/ecdsa"
	"errors"
	"io"
	"sync"

	objectV2 "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/object"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/rpc"
	rpcclient "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/rpc/client"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/session"
	"git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/signature"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/client"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/network"
	"git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/internal"
	internalclient "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object/internal/client"
	"git.frostfs.info/TrueCloudLab/frostfs-observability/tracing"
	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)

// getRequestForwarder forwards a GET request to remote nodes and relays the
// received object parts into Stream. A single forwarder may be used for
// several forwarding attempts of the same request, so it keeps cross-call
// state: the request is rewritten and re-signed at most once (OnceResign),
// the header is written to Stream at most once (OnceHeaderSending), and
// GlobalProgress counts payload bytes already written to Stream.
type getRequestForwarder struct {
	OnceResign        sync.Once
	OnceHeaderSending sync.Once
	GlobalProgress    int
	Key               *ecdsa.PrivateKey
	Request           *objectV2.GetRequest
	Stream            *streamObjectWriter

	// resignErr keeps the result of the one-time request re-signing so that
	// every forwarding attempt, not only the first one, observes a failure.
	resignErr error
}

// forwardRequestToNode sends f.Request to the node at addr using c and relays
// the response parts into f.Stream.
//
// On the first call the request meta header is replaced by a new one with the
// TTL decremented and the previous header set as its origin, and the request
// is re-signed with f.Key; subsequent calls reuse that request. If signing
// failed, every call returns the signing error.
//
// The returned object is always nil: the object is delivered via f.Stream.
func (f *getRequestForwarder) forwardRequestToNode(ctx context.Context, addr network.Address, c client.MultiAddressClient, pubkey []byte) (*objectSDK.Object, error) {
	ctx, span := tracing.StartSpanFromContext(ctx, "getRequestForwarder.forwardRequestToNode",
		trace.WithAttributes(attribute.String("address", addr.String())),
	)
	defer span.End()

	// Compose and re-sign the forwarding request exactly once. The error is
	// stored on f instead of a local variable: sync.Once runs the closure only
	// on the first call, so a local would stay nil on later calls and a request
	// whose signing failed would be forwarded anyway.
	f.OnceResign.Do(func() {
		// compose meta header of the local server
		metaHdr := new(session.RequestMetaHeader)
		// NOTE(review): if the TTL is already 0 this subtraction may underflow;
		// presumably the caller only forwards when TTL > 1 — confirm.
		metaHdr.SetTTL(f.Request.GetMetaHeader().GetTTL() - 1)
		// TODO: #1165 think how to set the other fields
		metaHdr.SetOrigin(f.Request.GetMetaHeader())
		writeCurrentVersion(metaHdr)
		f.Request.SetMetaHeader(metaHdr)
		f.resignErr = signature.SignServiceMessage(f.Key, f.Request)
	})
	if f.resignErr != nil {
		return nil, f.resignErr
	}

	getStream, err := f.openStream(ctx, addr, c)
	if err != nil {
		return nil, err
	}
	return nil, f.readStream(ctx, c, getStream, pubkey)
}

// verifyResponse validates one response read from the remote node, in order:
//   - the response key is checked against pubkey (internal.VerifyResponseKeyV2),
//     whose error is returned unchanged;
//   - the service message signatures are verified, failures being wrapped
//     with errResponseVerificationFailed;
//   - finally the status from the response meta header is passed to checkStatus.
func (f *getRequestForwarder) verifyResponse(resp *objectV2.GetResponse, pubkey []byte) error {
	keyErr := internal.VerifyResponseKeyV2(pubkey, resp)
	if keyErr != nil {
		return keyErr
	}

	sigErr := signature.VerifyServiceMessage(resp)
	if sigErr != nil {
		return errResponseVerificationFailed(sigErr)
	}

	status := resp.GetMetaHeader().GetStatus()
	return checkStatus(status)
}

// writeHeader writes the object header carried by the init part v (object ID,
// signature and header) to f.Stream.
//
// The write happens at most once per forwarder, guarded by
// f.OnceHeaderSending; calls after the first one are no-ops returning nil
// (e.g. when the forwarder is reused for another node). A write failure is
// wrapped with errCouldNotWriteObjHeader.
func (f *getRequestForwarder) writeHeader(ctx context.Context, v *objectV2.GetObjectPartInit) error {
	var err error
	f.OnceHeaderSending.Do(func() {
		// Build the header-only object inside the once-guard: once the header
		// has been sent, an object built outside would just be discarded.
		obj := new(objectV2.Object)
		obj.SetObjectID(v.GetObjectID())
		obj.SetSignature(v.GetSignature())
		obj.SetHeader(v.GetHeader())

		err = f.Stream.WriteHeader(ctx, objectSDK.NewFromV2(obj))
	})
	if err != nil {
		return errCouldNotWriteObjHeader(err)
	}
	return nil
}

// openStream sends f.Request as a GET RPC to addr through c and returns the
// reader for the response stream. The RPC is bound to ctx. Any failure to
// open the stream is wrapped with errStreamOpenningFailed.
func (f *getRequestForwarder) openStream(ctx context.Context, addr network.Address, c client.MultiAddressClient) (*rpc.GetResponseReader, error) {
	var reader *rpc.GetResponseReader

	openErr := c.RawForAddress(ctx, addr, func(cli *rpcclient.Client) (rpcErr error) {
		reader, rpcErr = rpc.GetObject(cli, f.Request, rpcclient.WithContext(ctx))
		return rpcErr
	})
	if openErr != nil {
		return nil, errStreamOpenningFailed(openErr)
	}

	return reader, nil
}

// readStream consumes getStream until EOF, verifying every response with
// verifyResponse and relaying its object part into f.Stream.
//
// The stream must start with exactly one init (header) part followed by
// payload chunks; any other order yields errWrongMessageSeq, and an EOF
// before the header yields io.ErrUnexpectedEOF. Chunks are trimmed by
// chunkToSend so that payload bytes already written by earlier attempts
// (f.GlobalProgress) are not written again. A split-info part ends reading
// with a split info error. Non-EOF read errors are reported to c and wrapped
// with errReadingResponseFailed.
func (f *getRequestForwarder) readStream(ctx context.Context, c client.MultiAddressClient, getStream *rpc.GetResponseReader, pubkey []byte) error {
	resp := new(objectV2.GetResponse)
	headerSeen := false
	consumed := 0 // payload bytes received from this stream, written or not

	for {
		if readErr := getStream.Read(resp); readErr != nil {
			switch {
			case !errors.Is(readErr, io.EOF):
				internalclient.ReportError(c, readErr)
				return errReadingResponseFailed(readErr)
			case headerSeen:
				return nil
			default:
				return io.ErrUnexpectedEOF
			}
		}

		if verifyErr := f.verifyResponse(resp, pubkey); verifyErr != nil {
			return verifyErr
		}

		switch part := resp.GetBody().GetObjectPart().(type) {
		case *objectV2.GetObjectPartInit:
			if headerSeen {
				return errWrongMessageSeq
			}
			headerSeen = true

			if hdrErr := f.writeHeader(ctx, part); hdrErr != nil {
				return hdrErr
			}
		case *objectV2.GetObjectPartChunk:
			if !headerSeen {
				return errWrongMessageSeq
			}

			received := part.GetChunk()
			pending := chunkToSend(f.GlobalProgress, consumed, received)
			consumed += len(received)

			if len(pending) == 0 {
				// everything in this chunk was already delivered
				continue
			}

			if wrErr := f.Stream.WriteChunk(ctx, pending); wrErr != nil {
				return errCouldNotWriteObjChunk("Get", wrErr)
			}
			f.GlobalProgress += len(pending)
		case *objectV2.SplitInfo:
			return objectSDK.NewSplitInfoError(objectSDK.NewSplitInfoFromV2(part))
		default:
			return errUnexpectedObjectPart(part)
		}
	}
}