package policy import ( "errors" "fmt" "strconv" "strings" "github.com/antlr/antlr4/runtime/Go/antlr" "github.com/nspcc-dev/neofs-api-go/pkg/netmap" "github.com/nspcc-dev/neofs-sdk-go/pkg/policy/parser" ) var ( // ErrInvalidNumber is returned when a value of SELECT is 0. ErrInvalidNumber = errors.New("policy: expected positive integer") // ErrUnknownClause is returned when a statement(clause) in a query is unknown. ErrUnknownClause = errors.New("policy: unknown clause") // ErrUnknownOp is returned when an operation in a query is unknown. ErrUnknownOp = errors.New("policy: unknown operation") // ErrUnknownFilter is returned when a value of FROM in a query is unknown. ErrUnknownFilter = errors.New("policy: filter not found") // ErrUnknownSelector is returned when a value of IN is unknown. ErrUnknownSelector = errors.New("policy: selector not found") // ErrSyntaxError is returned for errors found by ANTLR parser. ErrSyntaxError = errors.New("policy: syntax error") ) type policyVisitor struct { errors []error parser.BaseQueryVisitor antlr.DefaultErrorListener } // Parse parses s into a placement policy. func Parse(s string) (*netmap.PlacementPolicy, error) { return parse(s) } func newPolicyVisitor() *policyVisitor { return &policyVisitor{} } func parse(s string) (*netmap.PlacementPolicy, error) { input := antlr.NewInputStream(s) lexer := parser.NewQueryLexer(input) stream := antlr.NewCommonTokenStream(lexer, 0) p := parser.NewQuery(stream) p.BuildParseTrees = true v := newPolicyVisitor() p.RemoveErrorListeners() p.AddErrorListener(v) pl := p.Policy().Accept(v) if len(v.errors) != 0 { return nil, v.errors[0] } if err := validatePolicy(pl.(*netmap.PlacementPolicy)); err != nil { return nil, err } return pl.(*netmap.PlacementPolicy), nil } func (p *policyVisitor) SyntaxError(recognizer antlr.Recognizer, offendingSymbol interface{}, line, column int, msg string, e antlr.RecognitionException) { p.reportError(fmt.Errorf("%w: line %d:%d %s", ErrSyntaxError, line, column, msg)) } func (p *policyVisitor) reportError(err error) interface{} { p.errors = append(p.errors, err) return nil } // VisitPolicy implements parser.QueryVisitor interface. func (p *policyVisitor) VisitPolicy(ctx *parser.PolicyContext) interface{} { if len(p.errors) != 0 { return nil } pl := new(netmap.PlacementPolicy) repStmts := ctx.AllRepStmt() rs := make([]*netmap.Replica, 0, len(repStmts)) for _, r := range repStmts { res, ok := r.Accept(p).(*netmap.Replica) if !ok { return nil } rs = append(rs, res) } pl.SetReplicas(rs...) if cbfStmt := ctx.CbfStmt(); cbfStmt != nil { cbf, ok := cbfStmt.(*parser.CbfStmtContext).Accept(p).(uint32) if !ok { return nil } pl.SetContainerBackupFactor(cbf) } selStmts := ctx.AllSelectStmt() ss := make([]*netmap.Selector, 0, len(selStmts)) for _, s := range selStmts { res, ok := s.Accept(p).(*netmap.Selector) if !ok { return nil } ss = append(ss, res) } pl.SetSelectors(ss...) filtStmts := ctx.AllFilterStmt() fs := make([]*netmap.Filter, 0, len(filtStmts)) for _, f := range filtStmts { fs = append(fs, f.Accept(p).(*netmap.Filter)) } pl.SetFilters(fs...) return pl } func (p *policyVisitor) VisitCbfStmt(ctx *parser.CbfStmtContext) interface{} { cbf, err := strconv.ParseUint(ctx.GetBackupFactor().GetText(), 10, 32) if err != nil { return p.reportError(ErrInvalidNumber) } return uint32(cbf) } // VisitRepStmt implements parser.QueryVisitor interface. func (p *policyVisitor) VisitRepStmt(ctx *parser.RepStmtContext) interface{} { num, err := strconv.ParseUint(ctx.GetCount().GetText(), 10, 32) if err != nil { return p.reportError(ErrInvalidNumber) } rs := new(netmap.Replica) rs.SetCount(uint32(num)) if sel := ctx.GetSelector(); sel != nil { rs.SetSelector(sel.GetText()) } return rs } // VisitSelectStmt implements parser.QueryVisitor interface. func (p *policyVisitor) VisitSelectStmt(ctx *parser.SelectStmtContext) interface{} { res, err := strconv.ParseUint(ctx.GetCount().GetText(), 10, 32) if err != nil { return p.reportError(ErrInvalidNumber) } s := new(netmap.Selector) s.SetCount(uint32(res)) if clStmt := ctx.Clause(); clStmt != nil { s.SetClause(clauseFromString(clStmt.GetText())) } if bStmt := ctx.GetBucket(); bStmt != nil { s.SetAttribute(ctx.GetBucket().GetText()) } s.SetFilter(ctx.GetFilter().GetText()) // either ident or wildcard if ctx.AS() != nil { s.SetName(ctx.GetName().GetText()) } return s } // VisitFilterStmt implements parser.QueryVisitor interface. func (p *policyVisitor) VisitFilterStmt(ctx *parser.FilterStmtContext) interface{} { f := p.VisitFilterExpr(ctx.GetExpr().(*parser.FilterExprContext)).(*netmap.Filter) f.SetName(ctx.GetName().GetText()) return f } func (p *policyVisitor) VisitFilterExpr(ctx *parser.FilterExprContext) interface{} { if eCtx := ctx.Expr(); eCtx != nil { return eCtx.Accept(p) } f := new(netmap.Filter) op := operationFromString(ctx.GetOp().GetText()) f.SetOperation(op) f1 := ctx.GetF1().Accept(p).(*netmap.Filter) f2 := ctx.GetF2().Accept(p).(*netmap.Filter) // Consider f1=(.. AND ..) AND f2. This can be merged because our AND operation // is of arbitrary arity. ANTLR generates left-associative parse-tree by default. if f1.Operation() == op { f.SetInnerFilters(append(f1.InnerFilters(), f2)...) return f } f.SetInnerFilters(f1, f2) return f } // VisitFilterKey implements parser.QueryVisitor interface. func (p *policyVisitor) VisitFilterKey(ctx *parser.FilterKeyContext) interface{} { if id := ctx.Ident(); id != nil { return id.GetText() } str := ctx.STRING().GetText() return str[1 : len(str)-1] } func (p *policyVisitor) VisitFilterValue(ctx *parser.FilterValueContext) interface{} { if id := ctx.Ident(); id != nil { return id.GetText() } if num := ctx.Number(); num != nil { return num.GetText() } str := ctx.STRING().GetText() return str[1 : len(str)-1] } // VisitExpr implements parser.QueryVisitor interface. func (p *policyVisitor) VisitExpr(ctx *parser.ExprContext) interface{} { f := new(netmap.Filter) if flt := ctx.GetFilter(); flt != nil { f.SetName(flt.GetText()) return f } key := ctx.GetKey().Accept(p) opStr := ctx.SIMPLE_OP().GetText() value := ctx.GetValue().Accept(p) f.SetKey(key.(string)) f.SetOperation(operationFromString(opStr)) f.SetValue(value.(string)) return f } // validatePolicy checks high-level constraints such as filter link in SELECT // being actually defined in FILTER section. func validatePolicy(p *netmap.PlacementPolicy) error { seenFilters := map[string]bool{} for _, f := range p.Filters() { seenFilters[f.Name()] = true } seenSelectors := map[string]bool{} for _, s := range p.Selectors() { if flt := s.Filter(); flt != netmap.MainFilterName && !seenFilters[flt] { return fmt.Errorf("%w: '%s'", ErrUnknownFilter, flt) } seenSelectors[s.Name()] = true } for _, r := range p.Replicas() { if sel := r.Selector(); sel != "" && !seenSelectors[sel] { return fmt.Errorf("%w: '%s'", ErrUnknownSelector, sel) } } return nil } func clauseFromString(s string) netmap.Clause { switch strings.ToUpper(s) { case "SAME": return netmap.ClauseSame case "DISTINCT": return netmap.ClauseDistinct default: // Such errors should be handled by ANTLR code thus this panic. panic(fmt.Errorf("BUG: invalid clause: %s", s)) } } func operationFromString(op string) netmap.Operation { switch strings.ToUpper(op) { case "AND": return netmap.OpAND case "OR": return netmap.OpOR case "EQ": return netmap.OpEQ case "NE": return netmap.OpNE case "GE": return netmap.OpGE case "GT": return netmap.OpGT case "LE": return netmap.OpLE case "LT": return netmap.OpLT default: // Such errors should be handled by ANTLR code thus this panic. panic(fmt.Errorf("BUG: invalid operation: %s", op)) } }