frostfs-api-go-pogpp/util/protogen/internalgengo/file.go
Evgenii Stratonikov 866db105ed [#107] protogen: Unify oneof getters with default protoc plugin
Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
2024-08-26 14:36:19 +03:00

245 lines
6.9 KiB
Go

package internalgengo
import (
"fmt"
"sort"
"strconv"
"strings"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
)
// Import paths (and one identifier) that the generator refers to when
// emitting code. Passing them to the generated file as protogen values
// rather than plain strings lets protogen manage the corresponding imports.
var (
// Standard-library packages.
strconvPackage = protogen.GoImportPath("strconv")
fmtPackage = protogen.GoImportPath("fmt")
jsonPackage = protogen.GoImportPath("encoding/json")
// Third-party protobuf encoding library.
easyprotoPackage = protogen.GoImportPath("github.com/VictoriaMetrics/easyproto")
// Packages from the frostfs-api-go module.
mpPackage = protogen.GoImportPath("git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/util/pool")
protoPackage = protogen.GoImportPath("git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/util/proto")
encodingPackage = protogen.GoImportPath("git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/util/proto/encoding")
// mp is the MarshalerPool identifier from the pool package.
mp = mpPackage.Ident("MarshalerPool")
)
// GenerateFile creates the "<prefix>_frostfs.pb.go" output for file and fills
// it with a generated-code header, the package clause and the code emitted
// for every top-level enum and message of the file. The populated generated
// file is returned to the caller.
func GenerateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
	out := gen.NewGeneratedFile(file.GeneratedFilenamePrefix+"_frostfs.pb.go", file.GoImportPath)

	out.P("// Code generated by protoc-gen-go-frostfs. DO NOT EDIT.")
	out.P()
	out.P("package ", file.GoPackageName)
	out.P()

	out.Import(encodingPackage)

	// Declaring a per-file pool here doesn't work for multiple files in a
	// single package, so an external pool is used instead:
	// g.P("var mp ", easyprotoPackage.Ident("MarshalerPool"))

	for i := range file.Enums {
		emitEnum(out, file.Enums[i])
	}
	for i := range file.Messages {
		emitEasyProto(out, file.Messages[i])
	}

	return out
}
// emitEnum writes the Go declarations for the protobuf enum e:
//   - an int32-based named type and one constant per enum value;
//   - the <Enum>_name (number -> name) and <Enum>_value (name -> number) maps;
//   - a String method that falls back to the decimal number for unknown values;
//   - a FromString method that reports whether s named a known value.
//
// Enums declared with allow_alias have several values sharing one number.
// Emitting each of them as a key of the <Enum>_name map literal would produce
// duplicate constant keys, which the Go compiler rejects, so every alias after
// the first is written as a comment instead (the same approach protoc-gen-go
// takes). Output for enums without aliases is unaffected.
func emitEnum(g *protogen.GeneratedFile, e *protogen.Enum) {
	g.P("type " + e.GoIdent.GoName + " int32")
	g.P("const (")
	for _, ev := range e.Values {
		g.P(ev.GoIdent.GoName, " ", e.GoIdent.GoName, " = ", ev.Desc.Number())
	}
	g.P(")")

	g.P("var (")
	g.P(e.GoIdent.GoName+"_name", " = map[int32]string{")
	for _, value := range e.Values {
		// ByNumber returns the first value declared with the given number,
		// so any other value carrying that number is an alias.
		duplicate := ""
		if value.Desc != e.Desc.Values().ByNumber(value.Desc.Number()) {
			duplicate = "// Duplicate value: "
		}
		g.P(duplicate, value.Desc.Number(), ": ", strconv.Quote(string(value.Desc.Name())), ",")
	}
	g.P("}")
	g.P(e.GoIdent.GoName+"_value", " = map[string]int32{")
	for _, value := range e.Values {
		// Names are unique within an enum, so aliases are safe here.
		g.P(strconv.Quote(string(value.Desc.Name())), ": ", value.Desc.Number(), ",")
	}
	g.P("}")
	g.P(")")
	g.P()

	g.P("func (x ", e.GoIdent.GoName, ") String() string {")
	g.P("if v, ok := ", e.GoIdent.GoName+"_name[int32(x)]; ok {")
	g.P("return v")
	g.P("}")
	g.P("return ", strconvPackage.Ident("FormatInt"), "(int64(x), 10)")
	g.P("}")

	g.P("func (x *", e.GoIdent.GoName, ") FromString(s string) bool {")
	g.P("if v, ok := ", e.GoIdent.GoName+"_value[s]; ok {")
	g.P("*x = ", e.GoIdent.GoName, "(v)")
	g.P("return true")
	g.P("}")
	g.P("return false")
	g.P("}")
}
// emitEasyProto emits all code generated for msg. Nested enums and nested
// messages come first (messages recursively), then the struct declaration,
// compile-time interface assertions, size/marshalling/accessor/JSON methods
// and, last, the interface and wrapper types of every oneof in the message.
// Signature methods are emitted only for messages whose Go name ends in
// "Request" or "Response".
func emitEasyProto(g *protogen.GeneratedFile, msg *protogen.Message) {
	for _, nested := range msg.Enums {
		emitEnum(g, nested)
	}
	for _, nested := range msg.Messages {
		emitEasyProto(g, nested)
	}

	typeName := msg.GoIdent.GoName

	g.P("type " + typeName + " struct {")
	emitMessageFields(g, msg)
	g.P("}")

	// Assert at compile time that the generated type implements the
	// interfaces it is expected to.
	implemented := []protogen.GoIdent{
		encodingPackage.Ident("ProtoMarshaler"),
		encodingPackage.Ident("ProtoUnmarshaler"),
		jsonPackage.Ident("Marshaler"),
		jsonPackage.Ident("Unmarshaler"),
	}
	g.P("var (")
	for _, iface := range implemented {
		g.P("_ ", iface, " = (*", typeName, ")(nil)")
	}
	g.P(")")

	emitStableSize(g, msg)

	isRPCMessage := strings.HasSuffix(typeName, "Request") || strings.HasSuffix(typeName, "Response")
	if isRPCMessage {
		emitSignatureMethods(g, msg)
	}

	emitProtoMethods(g, msg)
	emitGettersSetters(g, msg)
	emitJSONMethods(g, msg)

	// Each oneof is generated once, triggered by its first member field.
	for _, f := range msg.Fields {
		if !isFirstOneof(f) {
			continue
		}
		genOneof(g, f)
	}
}
// isFirstOneof reports whether f belongs to a oneof and is the first field
// listed in that oneof.
func isFirstOneof(f *protogen.Field) bool {
	if f.Oneof == nil {
		return false
	}
	return f.Oneof.Fields[0] == f
}
// emitOneofGettersSetters emits the accessors for ff, a member field of a
// oneof declared in msg:
//   - Get<Field> on the message, returning the wrapped value when the oneof
//     currently holds this member and the field's default value otherwise;
//   - Set<Field> on the message;
//   - for non-message kinds only, a getter/setter pair on the wrapper type
//     itself, produced by emitGetterSetter.
func emitOneofGettersSetters(g *protogen.GeneratedFile, msg *protogen.Message, ff *protogen.Field) {
	// For some reason protoc generates different code for oneof message/non-message fields:
	// 1. For message we have 2 level struct wrapping and setters use inner type.
	// 2. For other types we also have 2 level wrapping, but setters use outer type.
	ft := fieldType(g, ff)

	g.P("func (x *", msg.GoIdent.GoName, ") Get", ff.GoName, "() ", ft, " {")
	g.P("if xx, ok := x.Get", ff.Oneof.GoName, "().(*", ff.GoIdent, "); ok { return xx.", ff.GoName, " }")
	g.P("return ", fieldDefaultValue(ff))
	g.P("}")

	if ff.Desc.Kind() == protoreflect.MessageKind {
		// The setter takes the inner value and wraps it itself.
		g.P("func (x *", msg.GoIdent.GoName, ") Set", ff.GoName, "(v ", ft, ") {")
		g.P("x.", ff.Oneof.GoName, " = &", ff.GoIdent, "{", ff.GoName, ": v}")
		g.P("}")
		return
	}

	// The setter takes the wrapper; the wrapper gets its own accessors.
	g.P("func (x *", msg.GoIdent.GoName, ") Set", ff.GoName, "(v *", ff.GoIdent, ") {")
	g.P("x.", ff.Oneof.GoName, " = v")
	g.P("}")

	// ft is reused from above: the original recomputed the same value in a
	// shadowing declaration here.
	emitGetterSetter(g, ff.GoIdent.GoName, ff.GoName, ft.String(), fieldDefaultValue(ff))
}
// emitGettersSetters emits accessors for every field of msg. A plain field
// gets one getter/setter pair. A oneof is handled once, when its first member
// is reached: a pair for the oneof field itself (default "nil") is emitted,
// followed by the per-member accessors. The other members are skipped.
func emitGettersSetters(g *protogen.GeneratedFile, msg *protogen.Message) {
	owner := msg.GoIdent.GoName

	for _, f := range msg.Fields {
		if f.Oneof == nil {
			typ := fieldType(g, f)
			emitGetterSetter(g, owner, f.GoName, typ.String(), fieldDefaultValue(f))
			continue
		}

		if f.Oneof.Fields[0] != f {
			continue
		}

		emitGetterSetter(g, owner, f.Oneof.GoName, oneOfDescriptor(f.Oneof), "nil")
		for _, member := range f.Oneof.Fields {
			emitOneofGettersSetters(g, msg, member)
		}
	}
}
// emitMessageFields emits the struct field declarations of msg by passing
// each of its fields, in order, to genMessageField.
func emitMessageFields(g *protogen.GeneratedFile, msg *protogen.Message) {
	for i := range msg.Fields {
		genMessageField(g, msg.Fields[i])
	}
}
// genMessageField emits the struct field declaration for field. A regular
// field is written as "<name> <type>" with a json struct tag. For a oneof,
// one interface-typed field is emitted when the first member is seen and the
// other members emit nothing.
func genMessageField(g *protogen.GeneratedFile, field *protogen.Field) {
	if oneof := field.Oneof; oneof != nil {
		if oneof.Fields[0] == field {
			g.P(oneof.GoName, " ", oneOfDescriptor(oneof))
		}
		return
	}

	goType := fieldType(g, field)
	tag := fmt.Sprintf(" `json:%q`", fieldJSONName(field))
	g.P(field.GoName, " ", goType, tag)
}
// oneOfDescriptor returns the name of the interface generated for oneof:
// the oneof's Go identifier name prefixed with "is".
func oneOfDescriptor(oneof *protogen.Oneof) string {
	const prefix = "is"
	return prefix + oneof.GoIdent.GoName
}
// genOneof emits the types backing the oneof that field belongs to: an
// interface with a single marker method named after the interface, then one
// single-field wrapper struct per oneof member, then the marker method
// implementation for each wrapper.
func genOneof(g *protogen.GeneratedFile, field *protogen.Field) {
	marker := oneOfDescriptor(field.Oneof)
	members := field.Oneof.Fields

	g.P("type ", marker, " interface {")
	g.P(marker, "()")
	g.P("}")
	g.P()

	// All wrapper structs are written first...
	for _, member := range members {
		g.P("type ", member.GoIdent, " struct {")
		g.P(member.GoName, " ", fieldType(g, member))
		g.P("}")
		g.P()
	}

	// ...followed by all marker method implementations.
	for _, member := range members {
		g.P("func (*", member.GoIdent, ") ", marker, "() {}")
		g.P()
	}
}
// fieldDefaultValue returns the Go source text of the value a generated
// getter returns for field when nothing is set: "nil" for repeated, message
// and bytes fields, "false" for bools, `""` for strings and "0" for every
// other kind.
func fieldDefaultValue(field *protogen.Field) string {
	desc := field.Desc
	if desc.Cardinality() == protoreflect.Repeated {
		return "nil"
	}

	kind := desc.Kind()
	switch {
	case kind == protoreflect.MessageKind || kind == protoreflect.BytesKind:
		return "nil"
	case kind == protoreflect.BoolKind:
		return "false"
	case kind == protoreflect.StringKind:
		return `""`
	}
	return "0"
}
// castFieldName returns the expression used in generated code to read f from
// the receiver x: the oneof field for oneof members, an int32 conversion of
// the field for enums, and the plain field selector otherwise.
func castFieldName(f *protogen.Field) string {
	switch {
	case f.Oneof != nil:
		return "x." + f.Oneof.GoName
	case f.Desc.Kind() == protoreflect.EnumKind:
		return "int32(x." + f.GoName + ")"
	default:
		return "x." + f.GoName
	}
}
// sortFields returns a copy of fs ordered by ascending field number. The
// input slice is left unmodified.
func sortFields(fs []*protogen.Field) []*protogen.Field {
	sorted := make([]*protogen.Field, len(fs))
	copy(sorted, fs)

	byNumber := func(a, b int) bool {
		return sorted[a].Desc.Number() < sorted[b].Desc.Number()
	}
	sort.Slice(sorted, byNumber)

	return sorted
}