package patcher

import (
	"context"
	"errors"
	"fmt"
	"io"

	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/transformer"
)

// Sentinel errors returned by the PatchApplier implementation in this package.
//
//   - ErrOffsetExceedsSize: a payload patch starts beyond the end of the
//     original payload (wrapped with the offending offset and size).
//   - ErrInvalidPatchOffsetOrder: a payload patch starts before the end of the
//     previously applied patch range (wrapped with both offsets).
//   - ErrPayloadPatchIsNil: ApplyPayloadPatch was given a nil patch.
//   - ErrAttrPatchAlreadyApplied: ApplyAttributesPatch was called more than once.
var (
	ErrOffsetExceedsSize       = errors.New("patch offset exceeds object size")
	ErrInvalidPatchOffsetOrder = errors.New("invalid patch offset order")
	ErrPayloadPatchIsNil       = errors.New("nil payload patch")
	ErrAttrPatchAlreadyApplied = errors.New("attribute patch already applied")
)

// PatchRes is the result of patch application.
type PatchRes struct {
	// AccessIdentifiers are the identifiers returned by the object writer's
	// Close once the patched object has been fully written.
	AccessIdentifiers *transformer.AccessIdentifiers
}

// PatchApplier is the interface that provides methods to apply header and payload patches.
//
// The implementation returned by New writes the object header from
// ApplyAttributesPatch and the payload from ApplyPayloadPatch and Close.
// NOTE(review): the writer presumably needs the header before any payload, so
// ApplyAttributesPatch is expected to be called first — confirm against
// transformer.ChunkedObjectWriter.
type PatchApplier interface {
	// ApplyAttributesPatch applies the patch only for the object's attributes.
	//
	// If replaceAttrs is true, the attributes are replaced by newAttrs;
	// otherwise newAttrs are merged into the existing ones (values of existing
	// keys are overwritten, new keys are appended).
	//
	// ApplyAttributesPatch can be invoked only once; every subsequent call
	// returns `ErrAttrPatchAlreadyApplied`.
	//
	// The call is idempotent for the original header if it's invoked with empty `newAttrs` and
	// `replaceAttrs = false`.
	ApplyAttributesPatch(ctx context.Context, newAttrs []objectSDK.Attribute, replaceAttrs bool) error

	// ApplyPayloadPatch applies the patch for the object's payload: the patch
	// chunk replaces Range.Length bytes of the original payload starting at
	// Range.Offset (with zero length the chunk is inserted at that offset).
	//
	// Each patch must start at or after the end of the previous patch's range,
	// otherwise an error wrapping `ErrInvalidPatchOffsetOrder` is returned.
	// A patch starting beyond the original payload size yields an error
	// wrapping `ErrOffsetExceedsSize`.
	//
	// ApplyPayloadPatch returns `ErrPayloadPatchIsNil` error if patch is nil.
	ApplyPayloadPatch(ctx context.Context, payloadPatch *objectSDK.PayloadPatch) error

	// Close closes PatchApplier when the patch stream is over: the part of the
	// original payload after the last patch is copied, then the object writer
	// is closed and its access identifiers are returned.
	Close(context.Context) (PatchRes, error)
}

// RangeProvider is the interface that provides a method to get original object payload
// by a given range.
type RangeProvider interface {
	// GetRange reads an original object payload by the given range.
	// The method returns io.Reader over the data range only. This means if the data is read out,
	// then GetRange has to be invoked to provide reader over the next range.
	//
	// GetRange has no error result, so failures must be reported through the
	// returned reader's Read. The patcher reads the reader until io.EOF and
	// aborts the patch on any other read error.
	GetRange(ctx context.Context, rng *objectSDK.Range) io.Reader
}

// patcher is the PatchApplier implementation returned by New.
type patcher struct {
	// rangeProvider supplies readers over ranges of the original payload.
	rangeProvider RangeProvider

	// objectWriter receives the patched header and payload.
	objectWriter transformer.ChunkedObjectWriter

	// currOffset is the position in the ORIGINAL payload up to which data has
	// already been emitted, either copied verbatim or replaced by a patch.
	currOffset uint64

	// originalPayloadSize is the payload size of the original header,
	// captured once in New.
	originalPayloadSize uint64

	// hdr is the header being patched; ApplyAttributesPatch changes its
	// attributes in place (it is the same pointer as Params.Header).
	hdr *objectSDK.Object

	// attrPatchAlreadyApplied is set by the first ApplyAttributesPatch call
	// and makes any later call fail with ErrAttrPatchAlreadyApplied.
	attrPatchAlreadyApplied bool

	// readerBuffSize is the size of the buffer used for each Read in copyRange.
	readerBuffSize int
}

const (
	// DefaultReaderBufferSize (64 KiB) is the read buffer size New uses when
	// Params.ReaderBufferSize is not positive.
	DefaultReaderBufferSize = 64 * 1024
)

// Params holds the parameters passed to New to initialize a patcher.
type Params struct {
	// Header is the original object header. New reads its payload size, and
	// ApplyAttributesPatch modifies its attributes in place before writing it.
	Header *objectSDK.Object

	// RangeProvider supplies readers over ranges of the original payload.
	RangeProvider RangeProvider

	// ObjectWriter is the writer that writes the patched object.
	ObjectWriter transformer.ChunkedObjectWriter

	// ReaderBufferSize is the size of the buffer used by the original payload
	// range reader. If it is <= 0, then `DefaultReaderBufferSize` is used.
	ReaderBufferSize int
}

// New returns a PatchApplier that patches the object described by
// prm.Header. It reads the original payload through prm.RangeProvider and
// emits the patched object to prm.ObjectWriter. A non-positive
// prm.ReaderBufferSize selects DefaultReaderBufferSize.
func New(prm Params) PatchApplier {
	bufSize := DefaultReaderBufferSize
	if prm.ReaderBufferSize > 0 {
		bufSize = prm.ReaderBufferSize
	}

	p := &patcher{
		hdr:            prm.Header,
		rangeProvider:  prm.RangeProvider,
		objectWriter:   prm.ObjectWriter,
		readerBuffSize: bufSize,
	}
	// The original size is fixed up front; later attribute patches mutate
	// the same header object.
	p.originalPayloadSize = prm.Header.PayloadSize()

	return p
}

// ApplyAttributesPatch updates the attributes of the original header and
// writes the resulting header to the object writer.
//
// If replaceAttrs is true, the header's attributes are replaced by newAttrs
// (an empty newAttrs clears them). Otherwise newAttrs is merged into the
// existing attributes; with an empty newAttrs the header is left untouched.
//
// Only the first call does any work. Every later call returns
// ErrAttrPatchAlreadyApplied without touching the header or the writer. A
// first call that fails still counts as applied.
func (p *patcher) ApplyAttributesPatch(ctx context.Context, newAttrs []objectSDK.Attribute, replaceAttrs bool) error {
	if p.attrPatchAlreadyApplied {
		return ErrAttrPatchAlreadyApplied
	}
	p.attrPatchAlreadyApplied = true

	switch {
	case replaceAttrs:
		p.hdr.SetAttributes(newAttrs...)
	case len(newAttrs) > 0:
		p.hdr.SetAttributes(mergeAttributes(newAttrs, p.hdr.Attributes())...)
	}

	if err := p.objectWriter.WriteHeader(ctx, p.hdr); err != nil {
		return fmt.Errorf("write header: %w", err)
	}
	return nil
}

// ApplyPayloadPatch applies a single payload patch. Any original payload
// between the end of the previous patch and the start of this one is copied
// first. The patch chunk is then written, replacing Range.Length original
// bytes.
//
// It returns ErrPayloadPatchIsNil for a nil patch. It returns an error
// wrapping ErrInvalidPatchOffsetOrder when the patch starts before the end of
// the previous patch range. It returns an error wrapping ErrOffsetExceedsSize
// when the patch starts beyond the original payload size.
func (p *patcher) ApplyPayloadPatch(ctx context.Context, payloadPatch *objectSDK.PayloadPatch) error {
	if payloadPatch == nil {
		return ErrPayloadPatchIsNil
	}

	offset := payloadPatch.Range.GetOffset()
	if offset < p.currOffset {
		return fmt.Errorf("%w: current = %d, previous = %d", ErrInvalidPatchOffsetOrder, offset, p.currOffset)
	}
	if offset > p.originalPayloadSize {
		return fmt.Errorf("%w: offset = %d, object size = %d", ErrOffsetExceedsSize, offset, p.originalPayloadSize)
	}

	newOffset, err := p.applyPatch(ctx, payloadPatch, p.currOffset)
	if err != nil {
		return fmt.Errorf("apply patch: %w", err)
	}
	// Commit the new position only once the patch was written successfully,
	// so a failed write does not move currOffset past data that never reached
	// the writer.
	p.currOffset = newOffset

	return nil
}

// Close finishes the patch stream. It copies the part of the original payload
// that follows the last applied patch, closes the object writer and returns
// the access identifiers reported by the writer.
func (p *patcher) Close(ctx context.Context) (PatchRes, error) {
	// A payload patch whose range runs past the end of the original payload
	// leaves currOffset > originalPayloadSize. Subtracting unconditionally
	// would wrap around (uint64) and request a bogus, enormous range. Copy
	// only when original bytes actually remain after the last patch.
	if p.currOffset < p.originalPayloadSize {
		rng := new(objectSDK.Range)
		rng.SetOffset(p.currOffset)
		rng.SetLength(p.originalPayloadSize - p.currOffset)

		// copy remaining original payload
		if err := p.copyRange(ctx, rng); err != nil {
			return PatchRes{}, fmt.Errorf("copy payload: %w", err)
		}
	}

	aid, err := p.objectWriter.Close(ctx)
	if err != nil {
		return PatchRes{}, fmt.Errorf("close object writer: %w", err)
	}

	return PatchRes{
		AccessIdentifiers: aid,
	}, nil
}

// copyRange streams the original payload range rng from the range provider
// into the object writer, one buffer of readerBuffSize bytes per Read.
//
// Bytes returned together with io.EOF are still written before the loop
// stops. Any other read error, or any write error, aborts the copy. A fresh
// buffer is allocated for every read, so a writer that keeps the written
// slice never sees it overwritten.
func (p *patcher) copyRange(ctx context.Context, rng *objectSDK.Range) error {
	src := p.rangeProvider.GetRange(ctx, rng)
	for eof := false; !eof; {
		chunk := make([]byte, p.readerBuffSize)
		n, rErr := src.Read(chunk)
		switch {
		case rErr == io.EOF:
			eof = true
		case rErr != nil:
			return fmt.Errorf("read: %w", rErr)
		}

		if _, wErr := p.objectWriter.Write(ctx, chunk[:n]); wErr != nil {
			return fmt.Errorf("write: %w", wErr)
		}
	}
	return nil
}

// applyPatch writes one payload patch, starting from position offset of the
// original payload. First it copies the untouched original bytes in
// [offset, patch offset). Then it writes the patch chunk.
//
// The returned offset is the position in the original payload right after
// the bytes covered by the patch range. On error it is the position reached
// so far: offset if the copy failed, or the patch start if the chunk write
// failed.
func (p *patcher) applyPatch(ctx context.Context, payloadPatch *objectSDK.PayloadPatch, offset uint64) (newOffset uint64, err error) {
	patchStart := payloadPatch.Range.GetOffset()
	pos := offset

	if patchStart > pos {
		gap := new(objectSDK.Range)
		gap.SetOffset(pos)
		gap.SetLength(patchStart - pos)

		if copyErr := p.copyRange(ctx, gap); copyErr != nil {
			return pos, fmt.Errorf("copy payload: %w", copyErr)
		}
		pos = patchStart
	}

	if _, wErr := p.objectWriter.Write(ctx, payloadPatch.Chunk); wErr != nil {
		return pos, wErr
	}

	// The chunk stands in for Range.Length original bytes, which are skipped.
	return pos + payloadPatch.Range.GetLength(), nil
}

// mergeAttributes merges newAttrs into oldAttrs and returns the result.
//
// Attributes from oldAttrs keep their position. If newAttrs contains the same
// key, the old attribute's value is overwritten. Keys absent from oldAttrs
// are appended in the order of their first occurrence in newAttrs. When
// newAttrs mentions a key more than once, the last value wins and the key
// appears only once in the result.
//
// oldAttrs is modified in place and its backing array may be reused by the
// returned slice.
func mergeAttributes(newAttrs, oldAttrs []objectSDK.Attribute) []objectSDK.Attribute {
	attrMap := make(map[string]string, len(newAttrs))
	for _, attr := range newAttrs {
		attrMap[attr.Key()] = attr.Value()
	}

	for i := range oldAttrs {
		key := oldAttrs[i].Key()
		newVal, ok := attrMap[key]
		if !ok {
			continue
		}
		oldAttrs[i].SetValue(newVal)
		delete(attrMap, key)
	}

	for _, attr := range newAttrs {
		key := attr.Key()
		val, ok := attrMap[key]
		if !ok {
			continue
		}
		// Drop the key once appended, so a key repeated in newAttrs does not
		// produce duplicate attributes. Use the last supplied value, matching
		// how existing keys are updated above.
		delete(attrMap, key)
		attr.SetValue(val)
		oldAttrs = append(oldAttrs, attr)
	}

	return oldAttrs
}