338 lines
8.4 KiB
Go
338 lines
8.4 KiB
Go
/*
|
|
* MinIO Cloud Storage, (C) 2019 MinIO, Inc.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
package sql
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/bcicen/jstream"
|
|
"github.com/minio/simdjson-go"
|
|
)
|
|
|
|
var (
	// errBadLimitSpecified is the error parseLimit returns when the LIMIT
	// clause value is not a number or is negative.
	errBadLimitSpecified = errors.New("Limit value must be a positive integer")
)
|
|
|
|
const (
	// baseTableName is the lower-case table name that validateTableName
	// requires as the base key of the FROM clause.
	baseTableName = "s3object"
)
|
|
|
|
// SelectStatement is the top level parsed and analyzed structure
|
|
type SelectStatement struct {
|
|
selectAST *Select
|
|
|
|
// Analysis result of the statement
|
|
selectQProp qProp
|
|
|
|
// Result of parsing the limit clause if one is present
|
|
// (otherwise -1)
|
|
limitValue int64
|
|
|
|
// Count of rows that have been output.
|
|
outputCount int64
|
|
}
|
|
|
|
// ParseSelectStatement - parses a select query from the given string
|
|
// and analyzes it.
|
|
func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
|
|
var selectAST Select
|
|
err = SQLParser.ParseString(s, &selectAST)
|
|
if err != nil {
|
|
err = errQueryParseFailure(err)
|
|
return
|
|
}
|
|
|
|
// Check if select is "SELECT s.* from S3Object s"
|
|
if !selectAST.Expression.All &&
|
|
len(selectAST.Expression.Expressions) == 1 &&
|
|
len(selectAST.Expression.Expressions[0].Expression.And) == 1 &&
|
|
len(selectAST.Expression.Expressions[0].Expression.And[0].Condition) == 1 &&
|
|
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand != nil &&
|
|
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left != nil &&
|
|
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left != nil &&
|
|
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary != nil &&
|
|
selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr != nil {
|
|
if selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr.String() == selectAST.From.As+".*" {
|
|
selectAST.Expression.All = true
|
|
}
|
|
}
|
|
stmt.selectAST = &selectAST
|
|
|
|
// Check the parsed limit value
|
|
stmt.limitValue, err = parseLimit(selectAST.Limit)
|
|
if err != nil {
|
|
err = errQueryAnalysisFailure(err)
|
|
return
|
|
}
|
|
|
|
// Analyze where clause
|
|
if selectAST.Where != nil {
|
|
whereQProp := selectAST.Where.analyze(&selectAST)
|
|
if whereQProp.err != nil {
|
|
err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %w", whereQProp.err))
|
|
return
|
|
}
|
|
|
|
if whereQProp.isAggregation {
|
|
err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
|
|
return
|
|
}
|
|
}
|
|
|
|
// Validate table name
|
|
err = validateTableName(selectAST.From)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Analyze main select expression
|
|
stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
|
|
err = stmt.selectQProp.err
|
|
if err != nil {
|
|
err = errQueryAnalysisFailure(err)
|
|
}
|
|
return
|
|
}
|
|
|
|
func validateTableName(from *TableExpression) error {
|
|
if strings.ToLower(from.Table.BaseKey.String()) != baseTableName {
|
|
return errBadTableName(errors.New("table name must be `s3object`"))
|
|
}
|
|
|
|
if len(from.Table.PathExpr) > 0 {
|
|
if !from.Table.PathExpr[0].ArrayWildcard {
|
|
return errBadTableName(errors.New("keypath table name is invalid - please check the service documentation"))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func parseLimit(v *LitValue) (int64, error) {
|
|
switch {
|
|
case v == nil:
|
|
return -1, nil
|
|
case v.Number == nil:
|
|
return -1, errBadLimitSpecified
|
|
default:
|
|
r := int64(*v.Number)
|
|
if r < 0 {
|
|
return -1, errBadLimitSpecified
|
|
}
|
|
return r, nil
|
|
}
|
|
}
|
|
|
|
// EvalFrom evaluates the From clause on the input record. It only
|
|
// applies to JSON input data format (currently).
|
|
func (e *SelectStatement) EvalFrom(format string, input Record) ([]*Record, error) {
|
|
if !e.selectAST.From.HasKeypath() {
|
|
return []*Record{&input}, nil
|
|
}
|
|
_, rawVal := input.Raw()
|
|
|
|
if format != "json" {
|
|
return nil, errDataSource(errors.New("path not supported"))
|
|
}
|
|
switch rec := rawVal.(type) {
|
|
case jstream.KVS:
|
|
txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], rec)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var kvs jstream.KVS
|
|
switch v := txedRec.(type) {
|
|
case jstream.KVS:
|
|
kvs = v
|
|
|
|
case []interface{}:
|
|
recs := make([]*Record, len(v))
|
|
for i, val := range v {
|
|
tmpRec := input.Clone(nil)
|
|
if err = tmpRec.Replace(val); err != nil {
|
|
return nil, err
|
|
}
|
|
recs[i] = &tmpRec
|
|
}
|
|
return recs, nil
|
|
|
|
default:
|
|
kvs = jstream.KVS{jstream.KV{Key: "_1", Value: v}}
|
|
}
|
|
|
|
if err = input.Replace(kvs); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return []*Record{&input}, nil
|
|
case simdjson.Object:
|
|
txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], rec)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch v := txedRec.(type) {
|
|
case simdjson.Object:
|
|
err := input.Replace(v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
case []interface{}:
|
|
recs := make([]*Record, len(v))
|
|
for i, val := range v {
|
|
tmpRec := input.Clone(nil)
|
|
if err = tmpRec.Replace(val); err != nil {
|
|
return nil, err
|
|
}
|
|
recs[i] = &tmpRec
|
|
}
|
|
return recs, nil
|
|
|
|
default:
|
|
input.Reset()
|
|
input, err = input.Set("_1", &Value{value: v})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return []*Record{&input}, nil
|
|
}
|
|
return nil, errDataSource(errors.New("unexpected non JSON input"))
|
|
}
|
|
|
|
// IsAggregated returns if the statement involves SQL aggregation
|
|
func (e *SelectStatement) IsAggregated() bool {
|
|
return e.selectQProp.isAggregation
|
|
}
|
|
|
|
// AggregateResult - returns the aggregated result after all input
|
|
// records have been processed. Applies only to aggregation queries.
|
|
func (e *SelectStatement) AggregateResult(output Record) error {
|
|
for i, expr := range e.selectAST.Expression.Expressions {
|
|
v, err := expr.evalNode(nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if expr.As != "" {
|
|
output, err = output.Set(expr.As, v)
|
|
} else {
|
|
output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
|
|
if e.selectAST.Where == nil {
|
|
return true, nil
|
|
}
|
|
value, err := e.selectAST.Where.evalNode(input)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
b, ok := value.ToBool()
|
|
if !ok {
|
|
err = fmt.Errorf("WHERE expression did not return bool")
|
|
return false, err
|
|
}
|
|
|
|
return b, nil
|
|
}
|
|
|
|
// AggregateRow - aggregates the input record. Applies only to
|
|
// aggregation queries.
|
|
func (e *SelectStatement) AggregateRow(input Record) error {
|
|
ok, err := e.isPassingWhereClause(input)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
for _, expr := range e.selectAST.Expression.Expressions {
|
|
err := expr.aggregateRow(input)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Eval - evaluates the Select statement for the given record. It
|
|
// applies only to non-aggregation queries.
|
|
// The function returns whether the statement passed the WHERE clause and should be outputted.
|
|
func (e *SelectStatement) Eval(input, output Record) (Record, error) {
|
|
ok, err := e.isPassingWhereClause(input)
|
|
if err != nil || !ok {
|
|
// Either error or row did not pass where clause
|
|
return nil, err
|
|
}
|
|
|
|
if e.selectAST.Expression.All {
|
|
// Return the input record for `SELECT * FROM
|
|
// .. WHERE ..`
|
|
|
|
// Update count of records output.
|
|
if e.limitValue > -1 {
|
|
e.outputCount++
|
|
}
|
|
return input.Clone(output), nil
|
|
}
|
|
|
|
for i, expr := range e.selectAST.Expression.Expressions {
|
|
v, err := expr.evalNode(input)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Pick output column names
|
|
if expr.As != "" {
|
|
output, err = output.Set(expr.As, v)
|
|
} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
|
|
output, err = output.Set(comp, v)
|
|
} else {
|
|
output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Update count of records output.
|
|
if e.limitValue > -1 {
|
|
e.outputCount++
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
// LimitReached - returns true if the number of records output has
|
|
// reached the value of the `LIMIT` clause.
|
|
func (e *SelectStatement) LimitReached() bool {
|
|
if e.limitValue == -1 {
|
|
return false
|
|
}
|
|
return e.outputCount >= e.limitValue
|
|
}
|