package internal_gengo import ( "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/reflect/protoreflect" ) func emitProtoMethods(g *protogen.GeneratedFile, msg *protogen.Message) { emitMarshalProtobuf(g, msg) emitUnmarshalProtobuf(g, msg) } func emitUnmarshalProtobuf(g *protogen.GeneratedFile, msg *protogen.Message) { g.P("// UnmarshalProtobuf implements the encoding.ProtoUnmarshaler interface.") g.P("func (x *", msg.GoIdent.GoName, ") UnmarshalProtobuf(src []byte) (err error) {") g.P("var fc ", easyprotoPackage.Ident("FieldContext")) g.P("for len(src) > 0 {") g.P("src, err = fc.NextField(src)") g.P("if err != nil { return ", fmtPackage.Ident("Errorf"), `("cannot read next field in %s", "`, msg.GoIdent.GoName, `")}`) g.P("switch fc.FieldNum {") for _, f := range msg.Fields { g.P("case ", f.Desc.Number(), ":", " // ", f.GoName) emitFieldUnmarshal(g, f) } g.P("}") g.P("}") // g.P("x.messageData = src") g.P("return nil") g.P("}") } func emitFieldUnmarshal(g *protogen.GeneratedFile, f *protogen.Field) { name := castFieldName(f) var getter string switch f.Desc.Kind() { case protoreflect.BoolKind: getter = "Bool" case protoreflect.EnumKind: getter = "Int32" case protoreflect.Int32Kind: getter = "Int32" case protoreflect.Sint32Kind: getter = "Sint32" case protoreflect.Uint32Kind: getter = "Uint32" case protoreflect.Int64Kind: getter = "Int64" case protoreflect.Sint64Kind: getter = "Sint64" case protoreflect.Uint64Kind: getter = "Uint64" case protoreflect.Sfixed32Kind: getter = "Sfixed32" case protoreflect.Fixed32Kind: getter = "Fixed32" case protoreflect.FloatKind: getter = "Float" case protoreflect.Sfixed64Kind: getter = "Sfixed64" case protoreflect.Fixed64Kind: getter = "Fixed64" case protoreflect.DoubleKind: getter = "Double" case protoreflect.StringKind: getter = "String" case protoreflect.BytesKind: getter = "Bytes" case protoreflect.MessageKind: g.P("data, ok := fc.MessageData()") g.P(`if !ok { return fmt.Errorf("cannot unmarshal field %s", "`, f.GoName, `") }`) if f.Desc.IsList() { g.P(name, " = append(", name, ", new(", fieldType(g, f)[3:], "))") g.P("ff := ", name, "[len(", name, ")-1]") name = "ff" } else if f.Oneof != nil { const tmp = "oneofField" g.P(tmp, " := &", f.GoIdent, "{", f.GoName, ": ", "new(", fieldType(g, f)[1:], ")}") defer g.P(name, " = ", tmp) name = tmp + "." + f.GoName } else { g.P(name, " = new(", fieldType(g, f)[1:], ")") } g.P(`if err := `, name, `.UnmarshalProtobuf(data); err != nil { return fmt.Errorf("unmarshal: %w", err)}`) return case protoreflect.GroupKind: panic("unimplemented") } if f.Desc.IsList() && (f.Desc.Kind() == protoreflect.BytesKind || f.Desc.Kind() == protoreflect.StringKind) { g.P("data, ok := fc.", getter, "()") g.P(`if !ok { return fmt.Errorf("cannot unmarshal field %s", "`, f.GoName, `") }`) g.P(name, " = append(", name, ", data)") return } if f.Desc.IsList() { g.P("data, ok := fc.Unpack", getter, "s(nil)") } else { g.P("data, ok := fc.", getter, "()") } g.P(`if !ok { return fmt.Errorf("cannot unmarshal field %s", "`, f.GoName, `") }`) value := "data" if f.Desc.Kind() == protoreflect.EnumKind { value = fieldType(g, f).String() + "(data)" } if f.Oneof == nil { g.P("x.", f.GoName, " = ", value) } else { g.P("x.", f.Oneof.GoName, " = &", f.GoIdent, "{", f.GoName, ": data}") } } func emitMarshalProtobuf(g *protogen.GeneratedFile, msg *protogen.Message) { g.P("// MarshalProtobuf implements the encoding.ProtoMarshaler interface.") g.P("func (x *", msg.GoIdent.GoName, ") MarshalProtobuf(dst []byte) []byte {") g.P("m := ", mp, ".Get()") g.P("defer ", mp, ".Put(m)") g.P("x.EmitProtobuf(m.MessageMarshaler())") g.P("dst = m.Marshal(dst)") g.P("return dst") g.P("}\n") g.P("func (x *", msg.GoIdent.GoName, ") EmitProtobuf(mm *", easyprotoPackage.Ident("MessageMarshaler"), ") {") if len(msg.Fields) != 0 { fs := sortFields(msg.Fields) g.P("if x == nil { return }") for _, f := range fs { emitFieldMarshal(g, f) } } g.P("}") } func emitFieldMarshal(g *protogen.GeneratedFile, f *protogen.Field) { name := castFieldName(f) if f.Oneof != nil { name = "x." + f.Oneof.GoName g.P("if inner, ok := ", name, ".(*", f.GoIdent.GoName, "); ok {") defer g.P("}") name = "inner." + f.GoName } cond := "!= 0" method := "" switch f.Desc.Kind() { case protoreflect.BoolKind: cond = "" method = "AppendBool" case protoreflect.EnumKind: method = "AppendInt32" case protoreflect.Int32Kind: method = "AppendInt32" case protoreflect.Sint32Kind: method = "AppendSint32" case protoreflect.Uint32Kind: method = "AppendUint32" case protoreflect.Int64Kind: method = "AppendInt64" case protoreflect.Sint64Kind: method = "AppendSint64" case protoreflect.Uint64Kind: method = "AppendUint64" case protoreflect.Sfixed32Kind: method = "AppendSfixed32" case protoreflect.Fixed32Kind: method = "AppendFixed32" case protoreflect.FloatKind: method = "AppendFloat" case protoreflect.Sfixed64Kind: method = "AppendSfixed64" case protoreflect.Fixed64Kind: method = "AppendFixed64" case protoreflect.DoubleKind: method = "AppendDouble" case protoreflect.StringKind: cond = `!= ""` method = "AppendString" case protoreflect.BytesKind: cond = `!= nil` method = "AppendBytes" case protoreflect.MessageKind: if f.Desc.IsList() { g.P("for i := range ", name, " {") defer g.P("}") name += "[i]" } g.P("if ", name, " != nil && ", name, ".StableSize() != 0 {") g.P(name, ".EmitProtobuf(mm.AppendMessage(", f.Desc.Number(), "))") g.P("}") return case protoreflect.GroupKind: panic("unimplemented") } if f.Desc.IsList() && (f.Desc.Kind() == protoreflect.BytesKind || f.Desc.Kind() == protoreflect.StringKind) { g.P("if ", name, " != nil {") g.P("for j := range ", name, " {") name += "[j]" if f.Desc.Kind() == protoreflect.BytesKind { g.P("mm.AppendBytes(", f.Desc.Number(), ", ", name, ")") } else { g.P("mm.AppendString(", f.Desc.Number(), ", ", name, ")") } g.P("}") g.P("}") return } if f.Desc.IsList() { method += "s" cond = "!= nil" } g.P("if ", name, " ", cond, " {") g.P("mm.", method, "(", f.Desc.Number(), ", ", name, ")") g.P("}") }