/*
This package contains helper functions for stable marshallers. Their usage is
entirely optional: a fast stable marshaller can be implemented without these
runtime function calls.
*/

package proto

import (
	"encoding/binary"
	"math"
	"math/bits"

	"google.golang.org/protobuf/encoding/protowire"
)

// stableSizer is implemented by values that can report the length of their
// stable binary encoding.
type stableSizer interface {
	StableSize() int
}

// stableMarshaler is a stableSizer that can also encode itself.
//
// Within this package StableMarshal is always handed a buffer that starts at
// the position where the value must be written, and the value is assumed to
// occupy exactly StableSize() bytes of it.
type stableMarshaler interface {
	stableSizer
	StableMarshal([]byte) []byte
}

// setMarshalData is a type constraint satisfied by pointer types *T that
// expose both StableSize and SetMarshalData.
type setMarshalData[T any] interface {
	~*T
	stableSizer
	SetMarshalData([]byte)
}

func BytesMarshal(field int, buf, v []byte) int {
	return bytesMarshal(field, buf, v)
}

func BytesSize(field int, v []byte) int {
	return bytesSize(field, v)
}

// bytesMarshal encodes a byte slice or string as a length-delimited field,
// omitting the field entirely when v is empty. It returns the number of
// bytes written to buf.
func bytesMarshal[T ~[]byte | ~string](field int, buf []byte, v T) int {
	if len(v) > 0 {
		return bytesMarshalNoCheck(field, buf, v)
	}
	return 0
}

// bytesMarshalNoCheck encodes v as a length-delimited field even when v is
// empty: tag, then payload length as a varint, then the payload itself.
// It returns the number of bytes written to buf.
func bytesMarshalNoCheck[T ~[]byte | ~string](field int, buf []byte, v T) int {
	tag := protowire.EncodeTag(protowire.Number(field), protowire.BytesType)

	// The length of buf is deliberately not validated: a check would avoid
	// the panic inside PutUvarint on a short buffer, but it would also slow
	// the marshaller down.
	n := binary.PutUvarint(buf, tag)
	n += binary.PutUvarint(buf[n:], uint64(len(v)))

	return n + copy(buf[n:], v)
}

// bytesSize returns the encoded size of a byte slice or string field,
// or 0 when v is empty (empty values are not encoded).
func bytesSize[T ~[]byte | ~string](field int, v T) int {
	if len(v) > 0 {
		return bytesSizeNoCheck(field, v)
	}
	return 0
}

// bytesSizeNoCheck returns the encoded size of v as a length-delimited field
// without treating an empty v specially.
func bytesSizeNoCheck[T ~[]byte | ~string](field int, v T) int {
	num := protowire.Number(field)
	// SizeBytes covers the length prefix plus the payload itself.
	payload := protowire.SizeBytes(len(v))
	// SizeGroup is used only as "payload + one tag"; no group is encoded.
	return protowire.SizeGroup(num, payload)
}

func StringMarshal(field int, buf []byte, v string) int {
	return bytesMarshal(field, buf, v)
}

func StringSize(field int, v string) int {
	return bytesSize(field, v)
}

// BoolMarshal writes a true value into buf as a varint field with the given
// field number and returns the number of bytes written. A false value is not
// encoded and yields 0.
func BoolMarshal(field int, buf []byte, v bool) int {
	// Single byte that represents "true" on the wire.
	const trueByte = 0x1

	if !v {
		return 0
	}

	tag := protowire.EncodeTag(protowire.Number(field), protowire.VarintType)

	// The length of buf is deliberately not validated: a check would avoid
	// the panic on a short buffer, but it would also slow the marshaller down.
	n := binary.PutUvarint(buf, tag)
	buf[n] = trueByte
	n++

	return n
}

// BoolSize returns the number of bytes BoolMarshal writes for v under the
// given field number: the tag plus one value byte for true, 0 for false.
func BoolSize(field int, v bool) int {
	// A true value always occupies a single byte after the tag.
	const valueLen = 1

	if v {
		return protowire.SizeGroup(protowire.Number(field), valueLen)
	}
	return 0
}

// UInt64Marshal writes v into buf as a varint field with the given field
// number and returns the number of bytes written. A zero v is not encoded
// and yields 0.
func UInt64Marshal(field int, buf []byte, v uint64) int {
	if v == 0 {
		return 0
	}

	tag := protowire.EncodeTag(protowire.Number(field), protowire.VarintType)

	// The length of buf is deliberately not validated: a check would avoid
	// the panic inside PutUvarint on a short buffer, but it would also slow
	// the marshaller down.
	n := binary.PutUvarint(buf, tag)

	return n + binary.PutUvarint(buf[n:], v)
}

// UInt64Size returns the number of bytes UInt64Marshal writes for v under the
// given field number: tag plus varint, or 0 when v is zero.
func UInt64Size(field int, v uint64) int {
	if v != 0 {
		varintLen := protowire.SizeVarint(v)
		return protowire.SizeGroup(protowire.Number(field), varintLen)
	}
	return 0
}

// Int64Marshal writes v into buf as a varint field with the given field
// number and returns the number of bytes written; zero is skipped.
//
// The value is reinterpreted as uint64, so negative numbers keep their
// two's-complement bit pattern.
func Int64Marshal(field int, buf []byte, v int64) int {
	raw := uint64(v)
	return UInt64Marshal(field, buf, raw)
}

// Int64Size returns the number of bytes Int64Marshal writes for v under the
// given field number; it is 0 when v is zero.
func Int64Size(field int, v int64) int {
	raw := uint64(v)
	return UInt64Size(field, raw)
}

// UInt32Marshal writes v into buf as a varint field with the given field
// number and returns the number of bytes written; zero is skipped.
func UInt32Marshal(field int, buf []byte, v uint32) int {
	wide := uint64(v)
	return UInt64Marshal(field, buf, wide)
}

// UInt32Size returns the number of bytes UInt32Marshal writes for v under the
// given field number; it is 0 when v is zero.
func UInt32Size(field int, v uint32) int {
	wide := uint64(v)
	return UInt64Size(field, wide)
}

// Int32Marshal writes v into buf as a varint field with the given field
// number and returns the number of bytes written; zero is skipped.
//
// The value goes through uint32 before being widened, so a negative v is
// encoded from its 32-bit pattern and is not sign-extended to 64 bits.
// NOTE(review): the standard protobuf int32 encoding sign-extends negative
// values; confirm that the shorter form produced here is intentional.
func Int32Marshal(field int, buf []byte, v int32) int {
	bits32 := uint32(v)
	return UInt64Marshal(field, buf, uint64(bits32))
}

// Int32Size returns the number of bytes Int32Marshal writes for v under the
// given field number; it is 0 when v is zero.
//
// Like Int32Marshal, it measures the 32-bit pattern of v without sign
// extension.
func Int32Size(field int, v int32) int {
	bits32 := uint32(v)
	return UInt64Size(field, uint64(bits32))
}

// EnumMarshal writes the enum value v into buf as a varint field with the
// given field number and returns the number of bytes written; zero is
// skipped.
//
// The value goes through uint32 before being widened, so a negative v is
// encoded from its 32-bit pattern without sign extension.
func EnumMarshal(field int, buf []byte, v int32) int {
	bits32 := uint32(v)
	return UInt64Marshal(field, buf, uint64(bits32))
}

// EnumSize returns the number of bytes EnumMarshal writes for v under the
// given field number; it is 0 when v is zero.
func EnumSize(field int, v int32) int {
	bits32 := uint32(v)
	return UInt64Size(field, uint64(bits32))
}

// RepeatedBytesMarshal writes every element of v into buf as a separate
// length-delimited field with the given field number and returns the total
// number of bytes written.
//
// Empty elements are encoded too, so the element count survives.
func RepeatedBytesMarshal(field int, buf []byte, v [][]byte) int {
	written := 0
	for _, item := range v {
		written += bytesMarshalNoCheck(field, buf[written:], item)
	}
	return written
}

// RepeatedBytesSize returns the number of bytes RepeatedBytesMarshal writes
// for v under the given field number. Empty elements are counted as well.
func RepeatedBytesSize(field int, v [][]byte) (size int) {
	for _, item := range v {
		size += bytesSizeNoCheck(field, item)
	}
	return size
}

// RepeatedStringMarshal writes every element of v into buf as a separate
// length-delimited field with the given field number and returns the total
// number of bytes written.
//
// Empty strings are encoded too, so the element count survives.
func RepeatedStringMarshal(field int, buf []byte, v []string) int {
	written := 0
	for _, item := range v {
		written += bytesMarshalNoCheck(field, buf[written:], item)
	}
	return written
}

// RepeatedStringSize returns the number of bytes RepeatedStringMarshal writes
// for v under the given field number. Empty strings are counted as well.
func RepeatedStringSize(field int, v []string) (size int) {
	for _, item := range v {
		size += bytesSizeNoCheck(field, item)
	}
	return size
}

// repeatedUIntSize computes the packed encoding size of v.
//
// It returns two values: size is the full field size (tag, length prefix and
// payload), arraySize is the payload alone, i.e. the sum of the varint
// lengths of all elements. Both are 0 for an empty v.
//
// Each element is converted straight to uint64, so negative signed values
// are sign-extended before being measured.
func repeatedUIntSize[T ~uint64 | ~int64 | ~uint32 | ~int32](field int, v []T) (size, arraySize int) {
	if len(v) == 0 {
		return 0, 0
	}

	for _, elem := range v {
		arraySize += protowire.SizeVarint(uint64(elem))
	}

	withLen := protowire.SizeBytes(arraySize)
	size = protowire.SizeGroup(protowire.Number(field), withLen)

	return size, arraySize
}

// repeatedUIntMarshal writes v into buf in packed form: a length-delimited
// field whose payload is the concatenation of the elements' varints.
// It returns the number of bytes written, or 0 for an empty v.
//
// Each element is converted straight to uint64, so negative signed values
// are sign-extended before being encoded.
func repeatedUIntMarshal[T ~uint64 | ~int64 | ~uint32 | ~int32](field int, buf []byte, v []T) int {
	if len(v) == 0 {
		return 0
	}

	tag := protowire.EncodeTag(protowire.Number(field), protowire.BytesType)
	pos := binary.PutUvarint(buf, tag)

	// Only the payload length is needed for the length prefix.
	_, payloadLen := repeatedUIntSize(field, v)
	pos += binary.PutUvarint(buf[pos:], uint64(payloadLen))

	for _, elem := range v {
		pos += binary.PutUvarint(buf[pos:], uint64(elem))
	}

	return pos
}

// RepeatedUInt64Marshal writes v into buf as a packed varint field with the
// given field number and returns the number of bytes written (0 for an
// empty v).
func RepeatedUInt64Marshal(field int, buf []byte, v []uint64) int {
	return repeatedUIntMarshal[uint64](field, buf, v)
}

// RepeatedUInt64Size returns the full encoded size of the packed field
// (size) and the size of its payload alone (arraySize). Both are 0 for an
// empty v.
func RepeatedUInt64Size(field int, v []uint64) (size, arraySize int) {
	size, arraySize = repeatedUIntSize[uint64](field, v)
	return size, arraySize
}

// RepeatedInt64Marshal writes v into buf as a packed varint field with the
// given field number and returns the number of bytes written (0 for an
// empty v).
func RepeatedInt64Marshal(field int, buf []byte, v []int64) int {
	return repeatedUIntMarshal[int64](field, buf, v)
}

// RepeatedInt64Size returns the full encoded size of the packed field (size)
// and the size of its payload alone (arraySize). Both are 0 for an empty v.
func RepeatedInt64Size(field int, v []int64) (size, arraySize int) {
	size, arraySize = repeatedUIntSize[int64](field, v)
	return size, arraySize
}

// RepeatedUInt32Marshal writes v into buf as a packed varint field with the
// given field number and returns the number of bytes written (0 for an
// empty v).
func RepeatedUInt32Marshal(field int, buf []byte, v []uint32) int {
	return repeatedUIntMarshal[uint32](field, buf, v)
}

// RepeatedUInt32Size returns the full encoded size of the packed field
// (size) and the size of its payload alone (arraySize). Both are 0 for an
// empty v.
func RepeatedUInt32Size(field int, v []uint32) (size, arraySize int) {
	size, arraySize = repeatedUIntSize[uint32](field, v)
	return size, arraySize
}

// RepeatedInt32Marshal writes v into buf as a packed varint field with the
// given field number and returns the number of bytes written (0 for an
// empty v).
//
// Every element goes through uint32 before being widened, so negative
// values are encoded from their 32-bit pattern without sign extension,
// matching Int32Marshal.
func RepeatedInt32Marshal(field int, buf []byte, v []int32) int {
	if len(v) == 0 {
		return 0
	}

	// Payload length: sum of the varint sizes of all elements.
	payloadLen := 0
	for _, elem := range v {
		payloadLen += protowire.SizeVarint(uint64(uint32(elem)))
	}

	tag := protowire.EncodeTag(protowire.Number(field), protowire.BytesType)
	pos := binary.PutUvarint(buf, tag)
	pos += binary.PutUvarint(buf[pos:], uint64(payloadLen))

	for _, elem := range v {
		pos += binary.PutUvarint(buf[pos:], uint64(uint32(elem)))
	}

	return pos
}

// RepeatedInt32Size returns the full encoded size of the packed field (size)
// and the size of its payload alone (arraySize). Both are 0 for an empty v.
//
// Every element is measured from its 32-bit pattern (via uint32), without
// sign extension.
func RepeatedInt32Size(field int, v []int32) (size, arraySize int) {
	if len(v) == 0 {
		return 0, 0
	}

	for _, elem := range v {
		arraySize += protowire.SizeVarint(uint64(uint32(elem)))
	}

	size = protowire.SizeGroup(protowire.Number(field), protowire.SizeBytes(arraySize))

	return size, arraySize
}

// VarUIntSize reports how many bytes the varint encoding of x occupies.
// A varint carries 7 payload bits per byte, and zero still takes one byte.
func VarUIntSize(x uint64) int {
	significant := bits.Len64(x)
	if significant == 0 {
		// x == 0 has no significant bits but is encoded as a single byte.
		significant = 1
	}
	// Round up to whole 7-bit groups.
	return (significant + 6) / 7
}

// ptrStableMarshaler is a type constraint satisfied by pointer types *T that
// implement stableMarshaler. Restricting to pointers lets generic helpers
// compare the value with nil.
type ptrStableMarshaler[T any] interface {
	~*T
	stableMarshaler
}

// ptrStableSizer is a type constraint satisfied by pointer types *T that
// implement stableSizer. Restricting to pointers lets generic helpers
// compare the value with nil.
type ptrStableSizer[T any] interface {
	~*T
	stableSizer
}

// NestedStructureMarshal writes the nested message v into buf as a
// length-delimited field with the given field number and returns the number
// of bytes written. A nil v is not encoded and yields 0.
func NestedStructureMarshal[T any, M ptrStableMarshaler[T]](field int64, buf []byte, v M) int {
	if v != nil {
		return NestedStructureMarshalUnchecked(field, buf, v)
	}
	return 0
}

// NestedStructureMarshalUnchecked writes v into buf as a length-delimited
// field with the given field number: tag, then v.StableSize() as a varint,
// then the bytes produced by v.StableMarshal.
//
// It performs no nil check on v. The return value is the header length plus
// v.StableSize(); the slice returned by StableMarshal is ignored.
func NestedStructureMarshalUnchecked[T stableMarshaler](field int64, buf []byte, v T) int {
	bodyLen := v.StableSize()

	tag := protowire.EncodeTag(protowire.Number(field), protowire.BytesType)
	header := binary.PutUvarint(buf, tag)
	header += binary.PutUvarint(buf[header:], uint64(bodyLen))

	// The nested message writes itself right after the header.
	v.StableMarshal(buf[header:])

	return header + bodyLen
}

// NestedStructureSetMarshalData hands the nested structure v the part of
// parentData that holds its own encoding.
//
// parentData is treated as starting at the tag of the nested field: the
// function skips the field header (tag plus length varint, both sized from
// field and v.StableSize()) and passes the following v.StableSize() bytes to
// v.SetMarshalData.
//
// It returns the total length of the field inside parentData (header plus
// body). A nil v yields 0 and is left untouched; a nil parentData resets v
// via SetMarshalData(nil) and also yields 0.
//
// No length validation is done: the slice expression panics if parentData
// cannot be extended to cover the whole field.
func NestedStructureSetMarshalData[T any, M setMarshalData[T]](field int64, parentData []byte, v M) int {
	if v == nil {
		return 0
	}

	if parentData == nil {
		v.SetMarshalData(nil)
		return 0
	}

	n := v.StableSize()

	// Only the header length is needed, so measure the two varints instead
	// of encoding them into a scratch buffer.
	prefix := protowire.EncodeTag(protowire.Number(field), protowire.BytesType)
	offset := protowire.SizeVarint(prefix) + protowire.SizeVarint(uint64(n))

	v.SetMarshalData(parentData[offset : offset+n])

	return offset + n
}

// NestedStructureSize returns the number of bytes needed to encode the
// nested message v as a length-delimited field with the given field number.
// A nil v takes no space, so 0 is returned for it.
func NestedStructureSize[T any, M ptrStableSizer[T]](field int64, v M) (size int) {
	if v != nil {
		size = NestedStructureSizeUnchecked(field, v)
	}
	return size
}

// NestedStructureSizeUnchecked returns the size of v encoded as a
// length-delimited field with the given field number: tag, length prefix and
// v.StableSize() bytes of body. It performs no nil check on v.
func NestedStructureSizeUnchecked[T stableSizer](field int64, v T) int {
	bodyLen := v.StableSize()
	withLen := protowire.SizeBytes(bodyLen)
	return protowire.SizeGroup(protowire.Number(field), withLen)
}

// Fixed64Marshal writes v into buf as a fixed64 field with the given field
// number: tag followed by 8 little-endian bytes. It returns the number of
// bytes written; a zero v is not encoded and yields 0.
func Fixed64Marshal(field int, buf []byte, v uint64) int {
	// Width of a fixed64 value on the wire.
	const valueLen = 8

	if v == 0 {
		return 0
	}

	tag := protowire.EncodeTag(protowire.Number(field), protowire.Fixed64Type)

	// The length of buf is deliberately not validated: a check would avoid
	// the panic on a short buffer, but it would also slow the marshaller down.
	n := binary.PutUvarint(buf, tag)
	binary.LittleEndian.PutUint64(buf[n:], v)

	return n + valueLen
}

// Fixed64Size returns the number of bytes Fixed64Marshal writes for v under
// field number fNum: tag plus the fixed64 width, or 0 when v is zero.
func Fixed64Size(fNum int, v uint64) int {
	if v != 0 {
		return protowire.SizeGroup(protowire.Number(fNum), protowire.SizeFixed64())
	}
	return 0
}

// Float64Marshal writes v into buf as a fixed64 field with the given field
// number: tag followed by the IEEE 754 bits of v in little-endian order.
// It returns the number of bytes written.
//
// Values that compare equal to 0 are not encoded and yield 0. This includes
// negative zero, so its sign is not preserved. No capacity check is
// performed on buf.
func Float64Marshal(field int, buf []byte, v float64) int {
	// Width of a fixed64 value on the wire.
	const valueLen = 8

	if v == 0 {
		return 0
	}

	tag := protowire.EncodeTag(protowire.Number(field), protowire.Fixed64Type)

	n := binary.PutUvarint(buf, tag)
	raw := math.Float64bits(v)
	binary.LittleEndian.PutUint64(buf[n:], raw)

	return n + valueLen
}

// Float64Size returns the number of bytes Float64Marshal writes for v under
// field number fNum: tag plus the fixed64 width, or 0 when v compares equal
// to 0 (negative zero included).
func Float64Size(fNum int, v float64) int {
	if v != 0 {
		return protowire.SizeGroup(protowire.Number(fNum), protowire.SizeFixed64())
	}
	return 0
}

// Fixed32Marshal writes v into buf as a Protocol Buffers fixed32 field with
// the given field number: tag followed by 4 little-endian bytes. It returns
// the number of bytes written; a zero v is not encoded and yields 0.
//
// Panics if the buffer is undersized.
func Fixed32Marshal(field int, buf []byte, v uint32) int {
	// Width of a fixed32 value on the wire.
	const valueLen = 4

	if v == 0 {
		return 0
	}

	tag := protowire.EncodeTag(protowire.Number(field), protowire.Fixed32Type)

	// The length of buf is deliberately not validated: a check would avoid
	// the panic on a short buffer, but it would also slow the marshaller down.
	n := binary.PutUvarint(buf, tag)
	binary.LittleEndian.PutUint32(buf[n:], v)

	return n + valueLen
}

// Fixed32Size returns the number of bytes Fixed32Marshal writes for v under
// field number fNum: tag plus the fixed32 width, or 0 when v is zero.
func Fixed32Size(fNum int, v uint32) int {
	if v != 0 {
		return protowire.SizeGroup(protowire.Number(fNum), protowire.SizeFixed32())
	}
	return 0
}