package internalgengo

import (
	"google.golang.org/protobuf/compiler/protogen"
	"google.golang.org/protobuf/reflect/protoreflect"
)

// emitProtoMethods generates the protobuf marshaling methods for msg followed
// by its unmarshaling method, in that order.
func emitProtoMethods(g *protogen.GeneratedFile, msg *protogen.Message) {
	emitters := []func(*protogen.GeneratedFile, *protogen.Message){
		emitMarshalProtobuf,
		emitUnmarshalProtobuf,
	}
	for _, emit := range emitters {
		emit(g, msg)
	}
}

// emitUnmarshalProtobuf generates the UnmarshalProtobuf method for msg.
//
// The generated method walks src one field at a time using an easyproto
// FieldContext and switches on the field number. Each field declared in msg
// gets a case whose body is produced by emitFieldUnmarshal. Field numbers not
// declared in msg match no case (the switch has no default), so they are
// skipped. Errors returned by FieldContext.NextField are wrapped with %w so
// callers of the generated code keep the underlying cause.
func emitUnmarshalProtobuf(g *protogen.GeneratedFile, msg *protogen.Message) {
	g.P("// UnmarshalProtobuf implements the encoding.ProtoUnmarshaler interface.")
	g.P("func (x *", msg.GoIdent.GoName, ") UnmarshalProtobuf(src []byte) (err error) {")
	g.P("var fc ", easyprotoPackage.Ident("FieldContext"))
	g.P("for len(src) > 0 {")
	{
		g.P("src, err = fc.NextField(src)")
		// Wrap (rather than drop) the decoder error so the root cause survives.
		g.P("if err != nil { return ", fmtPackage.Ident("Errorf"), `("cannot read next field in %s: %w", "`, msg.GoIdent.GoName, `", err)}`)
		g.P("switch fc.FieldNum {")
		{
			for _, f := range msg.Fields {
				g.P("case ", f.Desc.Number(), ":", " // ", f.GoName)
				emitFieldUnmarshal(g, f)
			}
		}
		g.P("}")
	}
	g.P("}")

	g.P("return nil")
	g.P("}")
}

// emitFieldUnmarshal generates the body of one `case` of the field-number
// switch emitted by emitUnmarshalProtobuf. The generated code reads the
// current field from the easyproto FieldContext `fc` and stores it into the
// receiver `x`:
//   - message fields are decoded with the nested type's UnmarshalProtobuf.
//     A repeated message appends a new zero element and decodes into it. A
//     oneof member is wrapped in its oneof struct, which is stored into the
//     oneof field after decoding. A singular message is freshly allocated.
//   - repeated bytes/string fields and non-packed repeated uint64 fields are
//     read one element per record and appended.
//   - other repeated scalars are unpacked and appended to the existing slice.
//     A packed field split across several records is therefore concatenated
//     instead of keeping only the last record.
//   - singular scalars are assigned directly (enums converted to the enum
//     type). Oneof members are wrapped in their oneof struct.
//
// Every failed read makes the generated method return an error naming the field.
func emitFieldUnmarshal(g *protogen.GeneratedFile, f *protogen.Field) {
	name := castFieldName(f)
	// Qualified through protogen so an aliased fmt import still resolves.
	errorf := fmtPackage.Ident("Errorf")
	emitOKCheck := func() {
		g.P("if !ok { return ", errorf, `("cannot unmarshal field %s", "`, f.GoName, `") }`)
	}

	if f.Desc.Kind() == protoreflect.MessageKind {
		g.P("data, ok := fc.MessageData()")
		emitOKCheck()
		switch {
		case f.Desc.IsList():
			g.P(name, " = append(", name, ", ", fieldType(g, f)[2:], "{})")
			g.P("ff := &", name, "[len(", name, ")-1]")
			name = "ff"
		case f.Oneof != nil:
			const tmp = "oneofField"
			g.P(tmp, " := &", f.GoIdent, "{", f.GoName, ": ", "new(", fieldType(g, f)[1:], ")}")
			// Store the wrapper only after the nested message has been decoded.
			// The oneof is addressed by its own Go name, as in the scalar
			// branch below and in emitMarshalOneof.
			defer g.P("x.", f.Oneof.GoName, " = ", tmp)

			name = tmp + "." + f.GoName
		default:
			g.P(name, " = new(", fieldType(g, f)[1:], ")")
		}

		g.P(`if err := `, name, `.UnmarshalProtobuf(data); err != nil { return `, errorf, `("unmarshal: %w", err)}`)
		return
	}

	kind := f.Desc.Kind()
	getter, _ := easyprotoKindInfo(kind)

	if f.Desc.IsList() && (kind == protoreflect.BytesKind || kind == protoreflect.StringKind || (kind == protoreflect.Uint64Kind && !f.Desc.IsPacked())) {
		g.P("data, ok := fc.", getter, "()")
		emitOKCheck()
		g.P(name, " = append(", name, ", data)")
		return
	}

	field := "x." + f.GoName
	if f.Desc.IsList() {
		// Pass the current slice as dst so easyproto appends to it. This
		// matches the append semantics of the other repeated branches above.
		g.P("data, ok := fc.Unpack", getter, "s(", field, ")")
	} else {
		g.P("data, ok := fc.", getter, "()")
	}

	emitOKCheck()
	value := "data"
	if kind == protoreflect.EnumKind {
		value = fieldType(g, f).String() + "(data)"
	}

	if f.Oneof == nil {
		g.P(field, " = ", value)
	} else {
		// Use the converted value so enum members of a oneof receive the
		// enum type rather than a raw int32.
		g.P("x.", f.Oneof.GoName, " = &", f.GoIdent, "{", f.GoName, ": ", value, "}")
	}
}

// emitMarshalProtobuf generates two methods for msg:
//   - MarshalProtobuf obtains a marshaler via mp.Get (released with mp.Put),
//     lets EmitProtobuf fill it, and appends the encoded bytes to dst.
//   - EmitProtobuf writes each field of a non-nil receiver into the given
//     easyproto MessageMarshaler, in the order returned by sortFields. For a
//     message without fields the body is empty.
func emitMarshalProtobuf(g *protogen.GeneratedFile, msg *protogen.Message) {
	recv := "func (x *" + msg.GoIdent.GoName + ") "

	g.P("// MarshalProtobuf implements the encoding.ProtoMarshaler interface.")
	g.P(recv, "MarshalProtobuf(dst []byte) []byte {")
	g.P("m := ", mp, ".Get()")
	g.P("defer ", mp, ".Put(m)")
	g.P("x.EmitProtobuf(m.MessageMarshaler())")
	g.P("dst = m.Marshal(dst)")
	g.P("return dst")
	g.P("}\n")

	g.P(recv, "EmitProtobuf(mm *", easyprotoPackage.Ident("MessageMarshaler"), ") {")
	if len(msg.Fields) == 0 {
		g.P("}")
		return
	}

	ordered := sortFields(msg.Fields)
	g.P("if x == nil { return }")
	for i := range ordered {
		emitFieldMarshal(g, ordered[i])
	}
	g.P("}")
}

// emitMarshalOneof generates marshaling code for f, a member of a oneof. The
// generated code type-asserts the oneof interface field to f's wrapper type.
// Only when the assertion succeeds does it marshal the wrapped value.
func emitMarshalOneof(g *protogen.GeneratedFile, f *protogen.Field) {
	wrapperPtr := "*" + f.GoIdent.GoName
	g.P("if inner, ok := x.", f.Oneof.GoName, ".(", wrapperPtr, "); ok {")
	defer g.P("}")

	inner := "inner." + f.GoName
	emitMarshalRaw(g, f, inner)
}

// easyprotoKindInfo maps a scalar protobuf kind to the type suffix used when
// naming easyproto methods (e.g. "Int32" in fc.Int32, fc.UnpackInt32s and
// mm.AppendInt32). Enums share the "Int32" suffix.
//
// The second result builds the condition that guards emission of a singular
// field of this kind (presumably true when the value is not the kind's
// default).
//
// It panics for group fields, which are not implemented, and for any other
// kind not listed here, such as messages, which callers handle before calling it.
func easyprotoKindInfo(kind protoreflect.Kind) (string, func(string) string) {
	switch kind {
	case protoreflect.BoolKind:
		return "Bool", identity
	case protoreflect.StringKind:
		return "String", notEmpty
	case protoreflect.BytesKind:
		return "Bytes", notEmpty
	case protoreflect.EnumKind, protoreflect.Int32Kind:
		return "Int32", notZero
	case protoreflect.Sint32Kind:
		return "Sint32", notZero
	case protoreflect.Uint32Kind:
		return "Uint32", notZero
	case protoreflect.Int64Kind:
		return "Int64", notZero
	case protoreflect.Sint64Kind:
		return "Sint64", notZero
	case protoreflect.Uint64Kind:
		return "Uint64", notZero
	case protoreflect.Sfixed32Kind:
		return "Sfixed32", notZero
	case protoreflect.Fixed32Kind:
		return "Fixed32", notZero
	case protoreflect.Sfixed64Kind:
		return "Sfixed64", notZero
	case protoreflect.Fixed64Kind:
		return "Fixed64", notZero
	case protoreflect.FloatKind:
		return "Float", notZero
	case protoreflect.DoubleKind:
		return "Double", notZero
	case protoreflect.GroupKind:
		panic("unimplemented")
	}
	panic("unreachable")
}

// emitFieldMarshal generates marshaling code for a single field of the
// receiver `x`. Oneof members are routed through emitMarshalOneof. Every other
// field is marshaled directly from the expression returned by castFieldName.
func emitFieldMarshal(g *protogen.GeneratedFile, f *protogen.Field) {
	if f.Oneof == nil {
		emitMarshalRaw(g, f, castFieldName(f))
		return
	}
	emitMarshalOneof(g, f)
}

// emitMarshalRaw generates code that writes field f, read from the Go
// expression name, into the easyproto MessageMarshaler `mm`:
//   - message fields: every list element, or a singular value guarded by
//     notNil(name), is emitted through its EmitProtobuf method into
//     mm.AppendMessage.
//   - non-packed repeated scalars: one Append<Kind> call per element.
//   - packed repeated scalars: a single Append<Kind>s call guarded by
//     notEmpty(name).
//   - singular scalars: Append<Kind> guarded by the condition that
//     easyprotoKindInfo returns for the kind.
func emitMarshalRaw(g *protogen.GeneratedFile, f *protogen.Field, name string) {
	num := f.Desc.Number()

	if f.Desc.Kind() == protoreflect.MessageKind {
		target := name
		if f.Desc.IsList() {
			g.P("for i := range ", name, " {")
			target = name + "[i]"
		} else {
			g.P("if ", notNil(name), " {")
		}
		g.P(target, ".EmitProtobuf(mm.AppendMessage(", num, "))")
		g.P("}")
		return
	}

	suffix, cond := easyprotoKindInfo(f.Desc.Kind())
	appendCall := "mm.Append" + suffix

	switch {
	case f.Desc.IsList() && !f.Desc.IsPacked():
		g.P("for j := range ", name, " {")
		g.P(appendCall, "(", num, ", ", name, "[j])")
		g.P("}")
		return
	case f.Desc.IsList():
		g.P("if ", notEmpty(name), "{")
		g.P(appendCall, "s(", num, ", ", name, ")")
	default:
		g.P("if ", cond(name), " {")
		g.P(appendCall, "(", num, ", ", name, ")")
	}
	g.P("}")
}