frostfs-api-go/util/protogen/main.go

262 lines
8.5 KiB
Go
Raw Normal View History

package main
import (
"sort"
"strings"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
)
// protowirePackage is the import path of the protowire package; the generated
// code references its identifiers via protowirePackage.Ident(...).
var protowirePackage = protogen.GoImportPath("google.golang.org/protobuf/encoding/protowire")

// binaryPackage is the import path of encoding/binary; the generated code
// references its identifiers via binaryPackage.Ident(...).
var binaryPackage = protogen.GoImportPath("encoding/binary")
// main runs the plugin: for every input file whose Go import path ends in
// "/tree" or "/control", a *_frostfs.pb.go file is generated.
//
// NOTE(review): files are selected only by import path; the usual
// file.Generate filter is intentionally not applied here (it was commented
// out) — confirm that generating for dependency files too is desired.
func main() {
	protogen.Options{}.Run(func(plugin *protogen.Plugin) error {
		for _, file := range plugin.Files {
			importPath := string(file.GoImportPath)
			isTree := strings.HasSuffix(importPath, "/tree")
			isControl := strings.HasSuffix(importPath, "/control")
			if !isTree && !isControl {
				continue
			}
			generateFile(plugin, file)
		}
		return nil
	})
}
// generateFile creates <prefix>_frostfs.pb.go for file and fills it with the
// stable-order size/marshal methods of every top-level message (nested
// messages are handled recursively by emitMessage). It returns the generated
// file handle.
//
// Enum type declarations are not emitted by this plugin.
func generateFile(gen *protogen.Plugin, file *protogen.File) *protogen.GeneratedFile {
	out := gen.NewGeneratedFile(file.GeneratedFilenamePrefix+"_frostfs.pb.go", file.GoImportPath)

	out.P("// Code generated by protoc-gen-go-frostfs. DO NOT EDIT.")
	out.P()
	out.P("package ", file.GoPackageName)
	out.P()
	// The proto helper package is referenced by plain "proto." in the emitted
	// code, so it is imported explicitly rather than through an Ident.
	out.P(`import "git.frostfs.info/TrueCloudLab/frostfs-api-go/v2/util/proto"`)

	for i := range file.Messages {
		emitMessage(out, file.Messages[i])
	}
	return out
}
// emitMessage writes the StableSize and StableMarshal methods for msg, after
// first recursing into every message nested inside it.
//
// Messages whose Go name ends in "Request" or "Response" additionally get
// SignedDataSize, ReadSignedData and SetSignature. The code generated for
// those calls x.GetBody() and assigns x.Signature (of type *Signature), so
// such messages are expected to have matching body and signature fields.
func emitMessage(g *protogen.GeneratedFile, msg *protogen.Message) {
	for _, inner := range msg.Messages {
		emitMessage(g, inner)
	}

	// Fields are emitted in field-number order so the binary form is stable.
	fs := sortFields(msg.Fields)
	typeName := msg.GoIdent.GoName

	emitStableSize(g, typeName, fs)
	emitStableMarshal(g, typeName, fs)
	if strings.HasSuffix(typeName, "Request") || strings.HasSuffix(typeName, "Response") {
		emitSignedDataMethods(g, typeName)
	}
}

// emitStableSize writes the StableSize method of typeName summing the encoded
// sizes of fs.
func emitStableSize(g *protogen.GeneratedFile, typeName string, fs []*protogen.Field) {
	g.P("// StableSize returns the size of x in protobuf format.")
	g.P("//")
	g.P("// Structures with the same field values have the same binary size.")
	g.P("func (x *", typeName, ") StableSize() (size int) {")
	g.P("if x == nil { return 0 }")
	if needsSizeScratch(fs) {
		// Scratch variable for size helpers that return two values.
		g.P("var n int")
	}
	for _, f := range fs {
		emitFieldSize(g, f)
	}
	g.P("return size")
	g.P("}\n")
}

// needsSizeScratch reports whether any field in fs is sized with a
// two-result proto.Repeated<Prefix>Size call, i.e. whether emitFieldSize will
// reference the generated scratch variable n.
func needsSizeScratch(fs []*protogen.Field) bool {
	for _, f := range fs {
		if !f.Desc.IsList() || !marshalers[f.Desc.Kind()].RepeatedDouble {
			continue
		}
		if f.Desc.Kind() == protoreflect.Uint64Kind && !f.Desc.IsPacked() {
			// Unpacked uint64 lists are sized element by element instead.
			continue
		}
		return true
	}
	return false
}

// emitStableMarshal writes the StableMarshal method of typeName serializing
// fs in the given order. For a message without fields the generated method
// simply returns buf unchanged.
func emitStableMarshal(g *protogen.GeneratedFile, typeName string, fs []*protogen.Field) {
	g.P("// StableMarshal marshals x in protobuf binary format with stable field order.")
	g.P("//")
	g.P("// If buffer length is less than x.StableSize(), new buffer is allocated.")
	g.P("//")
	g.P("// Returns any error encountered which did not allow writing the data completely.")
	g.P("// Otherwise, returns the buffer in which the data is written.")
	g.P("//")
	g.P("// Structures with the same field values have the same binary format.")
	g.P("func (x *", typeName, ") StableMarshal(buf []byte) []byte {")
	if len(fs) != 0 {
		g.P("if x == nil { return []byte{} }")
		g.P("if buf == nil { buf = make([]byte, x.StableSize()) }")
		g.P("var offset int")
		for _, f := range fs {
			emitFieldMarshal(g, f)
		}
	}
	g.P("return buf")
	g.P("}\n")
}

// emitSignedDataMethods writes SignedDataSize, ReadSignedData and
// SetSignature for a request/response type. The signed data is the message
// body alone: both data methods delegate to x.GetBody().
func emitSignedDataMethods(g *protogen.GeneratedFile, typeName string) {
	g.P("// SignedDataSize returns size of the request signed data in bytes.")
	g.P("//")
	g.P("// Structures with the same field values have the same signed data size.")
	g.P("func (x *", typeName, ") SignedDataSize() int {")
	g.P("return x.GetBody().StableSize()")
	g.P("}\n")

	// StableMarshal returns no error, so the generated error is always nil.
	g.P("// ReadSignedData fills buf with signed data of x.")
	g.P("// If buffer length is less than x.SignedDataSize(), new buffer is allocated.")
	g.P("//")
	g.P("// Returns any error encountered which did not allow writing the data completely.")
	g.P("// Otherwise, returns the buffer in which the data is written.")
	g.P("//")
	g.P("// Structures with the same field values have the same signed data.")
	g.P("func (x *", typeName, ") ReadSignedData(buf []byte) ([]byte, error) {")
	g.P("return x.GetBody().StableMarshal(buf), nil")
	g.P("}\n")

	g.P("// SetSignature sets the signature of x.")
	g.P("func (x *", typeName, ") SetSignature(sig *Signature) {")
	g.P("x.Signature = sig")
	g.P("}\n")
}
// emitFieldSize writes the statement(s) that add the encoded size of field f
// to the `size` accumulator of the generated StableSize method.
//
// Kinds missing from marshalers produce a FIXME comment and a runtime panic in
// the generated code. Packed repeated kinds flagged RepeatedDouble use the
// scratch variable `n`, which the caller must have declared.
func emitFieldSize(g *protogen.GeneratedFile, f *protogen.Field) {
	m := marshalers[f.Desc.Kind()]
	if m.Prefix == "" {
		g.P("// FIXME missing field marshaler: ", f.GoName, " of type ", f.Desc.Kind().String())
		g.P(`panic("unimplemented")`)
		return
	}

	name := castFieldName(f)
	if f.Oneof != nil {
		// A oneof member is stored in a wrapper type behind the oneof
		// interface field; only size it when that wrapper is the one set.
		g.P("if inner, ok := x.", f.Oneof.GoName, ".(*", f.GoIdent.GoName, "); ok {")
		defer g.P("}")
		// castFieldName always reads through x, so rebuild the expression for
		// the unwrapped value while keeping the int32 conversion for enums.
		name = "inner." + f.GoName
		if f.Desc.Kind() == protoreflect.EnumKind {
			name = "int32(" + name + ")"
		}
	}

	num := f.Desc.Number()
	switch {
	case f.Desc.IsList() && f.Desc.Kind() == protoreflect.MessageKind:
		g.P("for i := range ", name, "{")
		g.P("size += proto.NestedStructureSize(", num, ", ", name, "[i])")
		g.P("}")
	case f.Desc.IsList() && f.Desc.Kind() == protoreflect.Uint64Kind && !f.Desc.IsPacked():
		// Unpacked: every element is its own tag + varint record.
		g.P("for i := range ", name, "{")
		g.P("size += ", protowirePackage.Ident("SizeTag"), "(",
			protowirePackage.Ident("Number"), "(", num, ")) + ",
			protowirePackage.Ident("SizeVarint"), "(", name, "[i])")
		g.P("}")
	case f.Desc.IsList() && m.RepeatedDouble:
		g.P("n, _ = proto.Repeated", m.Prefix, "Size(", num, ", ", name, ")")
		g.P("size += n")
	case f.Desc.IsList():
		g.P("size += proto.Repeated", m.Prefix, "Size(", num, ", ", name, ")")
	default:
		g.P("size += proto.", m.Prefix, "Size(", num, ", ", name, ")")
	}
}
// emitFieldMarshal writes the statement(s) that serialize field f into the
// generated `buf` starting at `offset`, advancing `offset` by the number of
// bytes written.
//
// Kinds missing from marshalers produce a FIXME comment and a runtime panic in
// the generated code.
func emitFieldMarshal(g *protogen.GeneratedFile, f *protogen.Field) {
	m := marshalers[f.Desc.Kind()]
	if m.Prefix == "" {
		g.P("// FIXME missing field marshaler: ", f.GoName, " of type ", f.Desc.Kind().String())
		g.P(`panic("unimplemented")`)
		return
	}

	name := castFieldName(f)
	if f.Oneof != nil {
		// A oneof member is stored in a wrapper type behind the oneof
		// interface field; only write it when that wrapper is the one set.
		g.P("if inner, ok := x.", f.Oneof.GoName, ".(*", f.GoIdent.GoName, "); ok {")
		defer g.P("}")
		// castFieldName always reads through x, so rebuild the expression for
		// the unwrapped value while keeping the int32 conversion for enums.
		name = "inner." + f.GoName
		if f.Desc.Kind() == protoreflect.EnumKind {
			name = "int32(" + name + ")"
		}
	}

	num := f.Desc.Number()
	switch {
	case f.Desc.IsList() && f.Desc.Kind() == protoreflect.MessageKind:
		g.P("for i := range ", name, "{")
		g.P("offset += proto.NestedStructureMarshal(", num, ", buf[offset:], ", name, "[i])")
		g.P("}")
	case f.Desc.IsList() && f.Desc.Kind() == protoreflect.Uint64Kind && !f.Desc.IsPacked():
		// Unpacked: write a varint tag followed by the varint value for every
		// element. EncodeTag already yields a uint64, so no conversion needed.
		g.P("for i := range ", name, "{")
		g.P("offset += ", binaryPackage.Ident("PutUvarint"), "(buf[offset:], ",
			protowirePackage.Ident("EncodeTag"), "(",
			protowirePackage.Ident("Number"), "(", num, "), ",
			protowirePackage.Ident("VarintType"), "))")
		g.P("offset += ", binaryPackage.Ident("PutUvarint"), "(buf[offset:], ", name, "[i])")
		g.P("}")
	case f.Desc.IsList():
		g.P("offset += proto.Repeated", m.Prefix, "Marshal(", num, ", buf[offset:], ", name, ")")
	default:
		g.P("offset += proto.", m.Prefix, "Marshal(", num, ", buf[offset:], ", name, ")")
	}
}
// castFieldName returns the Go expression that reads field f from the
// generated receiver x. Enum fields are wrapped in an int32 conversion
// (presumably because the proto Enum* helpers take int32 — confirm against
// the util/proto package).
func castFieldName(f *protogen.Field) string {
	expr := "x." + f.GoName
	if f.Desc.Kind() == protoreflect.EnumKind {
		expr = "int32(" + expr + ")"
	}
	return expr
}
// marshalerDesc describes how the generator serializes one protobuf field kind.
type marshalerDesc struct {
	// Prefix is the kind-specific part of the proto helper names used in the
	// generated code, e.g. proto.<Prefix>Size / proto.<Prefix>Marshal and
	// proto.Repeated<Prefix>Size / proto.Repeated<Prefix>Marshal.
	// An empty Prefix (the zero value) means the kind is unsupported.
	Prefix string
	// RepeatedDouble is set when proto.Repeated<Prefix>Size returns two
	// values; the generated StableSize then keeps the first and discards the
	// second.
	RepeatedDouble bool
}
// marshalers maps each supported protobuf field kind to the proto helper
// family used to size and serialize it. A kind with no entry yields the zero
// marshalerDesc (empty Prefix), which makes the generator emit a FIXME comment
// and a panic into the generated code.
var marshalers = map[protoreflect.Kind]marshalerDesc{
	// Varint-encoded scalars.
	protoreflect.BoolKind:   {Prefix: "Bool"},
	protoreflect.EnumKind:   {Prefix: "Enum"},
	protoreflect.Uint32Kind: {Prefix: "UInt32", RepeatedDouble: true},
	protoreflect.Int64Kind:  {Prefix: "Int64", RepeatedDouble: true},
	protoreflect.Uint64Kind: {Prefix: "UInt64", RepeatedDouble: true},

	// Fixed-width scalars.
	protoreflect.Fixed32Kind: {Prefix: "Fixed32", RepeatedDouble: true},
	protoreflect.Fixed64Kind: {Prefix: "Fixed64", RepeatedDouble: true},
	protoreflect.DoubleKind:  {Prefix: "Float64"},

	// Length-delimited values.
	protoreflect.StringKind:  {Prefix: "String"},
	protoreflect.BytesKind:   {Prefix: "Bytes"},
	protoreflect.MessageKind: {Prefix: "NestedStructure"},

	// Currently unused, therefore not mapped: Int32Kind, Sint32Kind,
	// Sint64Kind, Sfixed32Kind, Sfixed64Kind, FloatKind, GroupKind.
}
// sortFields returns a copy of fs ordered by ascending protobuf field number.
// The input slice is left untouched.
func sortFields(fs []*protogen.Field) []*protogen.Field {
	sorted := make([]*protogen.Field, len(fs))
	copy(sorted, fs)

	byNumber := func(a, b int) bool {
		return sorted[a].Desc.Number() < sorted[b].Desc.Number()
	}
	sort.Slice(sorted, byNumber)
	return sorted
}