diff --git a/pkg/services/oracle/filter_test.go b/pkg/services/oracle/filter_test.go index 134d94b6d..27a08b433 100644 --- a/pkg/services/oracle/filter_test.go +++ b/pkg/services/oracle/filter_test.go @@ -1,6 +1,7 @@ package oracle import ( + "strings" "testing" "github.com/stretchr/testify/require" @@ -49,3 +50,34 @@ func TestFilter(t *testing.T) { require.Error(t, err) }) } + +func TestFilterOOM(t *testing.T) { + construct := func(depth, width int) string { + data := `$` + for i := 0; i < depth; i++ { + data = data + `[0` + for j := 1; j < width; j++ { + data = data + `,0` + } + data = data + `]` + } + return data + } + t.Run("good", func(t *testing.T) { + expected := "[" + strings.Repeat("{},", 1023) + "{}]" + data := construct(2, 32) + actual, err := filter([]byte("[[{}]]"), data) + require.NoError(t, err) + require.JSONEq(t, expected, string(actual)) + }) + t.Run("too big", func(t *testing.T) { + data := construct(3, 32) + _, err := filter([]byte("[[[[[[{}]]]]]]"), data) + require.Error(t, err) + }) + t.Run("no oom", func(t *testing.T) { + data := construct(6, 64) + _, err := filter([]byte("[[[[[[{}]]]]]]"), data) + require.Error(t, err) + }) +} diff --git a/pkg/services/oracle/jsonpath/jsonpath.go b/pkg/services/oracle/jsonpath/jsonpath.go index 2afab962a..061b3af55 100644 --- a/pkg/services/oracle/jsonpath/jsonpath.go +++ b/pkg/services/oracle/jsonpath/jsonpath.go @@ -33,7 +33,10 @@ const ( pathNumber ) -const maxNestingDepth = 6 +const ( + maxNestingDepth = 6 + maxObjects = 1024 +) // Get returns substructures of value selected by path. // The result is always non-nil unless path is invalid. @@ -63,7 +66,7 @@ func Get(path string, value interface{}) ([]interface{}, bool) { objs, ok = p.processLeftBracket(objs) } - if !ok { + if !ok || maxObjects < len(objs) { return nil, false } } @@ -196,8 +199,14 @@ func (p *pathParser) descend(objs []interface{}) ([]interface{}, bool) { for i := range objs { switch obj := objs[i].(type) { case []interface{}: + if maxObjects < len(values)+len(obj) { + return nil, false + } values = append(values, obj...) case json.OrderedObject: + if maxObjects < len(values)+len(obj) { + return nil, false + } for i := range obj { values = append(values, obj[i].Value) } @@ -218,6 +227,9 @@ func (p *pathParser) descendRecursive(objs []interface{}) ([]interface{}, bool) for len(objs) > 0 { newObjs, _ := p.descendByIdentAux(objs, false, val) + if maxObjects < len(values)+len(newObjs) { + return nil, false + } values = append(values, newObjs...) objs, _ = p.descend(objs) } @@ -248,6 +260,9 @@ func (p *pathParser) descendByIdentAux(objs []interface{}, checkDepth bool, name for j := range names { for k := range obj { if obj[k].Key == names[j] { + if maxObjects < len(values)+1 { + return nil, false + } values = append(values, obj[k].Value) break } @@ -276,6 +291,9 @@ func (p *pathParser) descendByIndex(objs []interface{}, indices ...int) ([]inter j += len(obj) } if 0 <= j && j < len(obj) { + if maxObjects < len(values)+1 { + return nil, false + } values = append(values, obj[j]) } } @@ -438,6 +456,9 @@ func (p *pathParser) descendByRange(objs []interface{}, start, end int) ([]inter if subEnd <= subStart { continue } + if maxObjects < len(values)+subEnd-subStart { + return nil, false + } values = append(values, arr[subStart:subEnd]...) } diff --git a/pkg/services/oracle/jsonpath/jsonpath_test.go b/pkg/services/oracle/jsonpath/jsonpath_test.go index 7f69f2bbf..eaeacbe2d 100644 --- a/pkg/services/oracle/jsonpath/jsonpath_test.go +++ b/pkg/services/oracle/jsonpath/jsonpath_test.go @@ -4,6 +4,7 @@ import ( "bytes" "math" "strconv" + "strings" "testing" json "github.com/nspcc-dev/go-ordered-json" @@ -210,6 +211,58 @@ func TestUnion(t *testing.T) { tc.testUnmarshalGet(t, js) }) } + + t.Run("big amount of intermediate objects", func(t *testing.T) { + // We want to fail as early as possible, this test covers all possible + // places where an overflow could first occur. The idea is that first steps + // construct intermediate array of 1000 < 1024, and the last step multiplies + // this amount by 2. + construct := func(width int, index string) string { + return "[" + strings.Repeat(index+",", width-1) + index + "]" + } + + t.Run("index, array", func(t *testing.T) { + jp := "$" + strings.Repeat(construct(10, "0"), 4) + _, ok := unmarshalGet(t, "[[[[{}]]]]", jp) + require.False(t, ok) + }) + + t.Run("asterisk, array", func(t *testing.T) { + jp := "$" + strings.Repeat(construct(10, `0`), 3) + ".*" + _, ok := unmarshalGet(t, `[[[[{},{}]]]]`, jp) + require.False(t, ok) + }) + + t.Run("range", func(t *testing.T) { + jp := "$" + strings.Repeat(construct(10, `0`), 3) + "[0:2]" + _, ok := unmarshalGet(t, `[[[[{},{}]]]]`, jp) + require.False(t, ok) + }) + + t.Run("recursive descent", func(t *testing.T) { + jp := "$" + strings.Repeat(construct(10, `0`), 3) + "..a" + _, ok := unmarshalGet(t, `[[[{"a":{"a":{}}}]]]`, jp) + require.False(t, ok) + }) + + t.Run("string union", func(t *testing.T) { + jp := "$" + strings.Repeat(construct(10, `0`), 3) + "['x','y']" + _, ok := unmarshalGet(t, `[[[{"x":{},"y":{}}]]]`, jp) + require.False(t, ok) + }) + + t.Run("index, map", func(t *testing.T) { + jp := "$" + strings.Repeat(construct(10, `"a"`), 4) + _, ok := unmarshalGet(t, `{"a":{"a":{"a":{"a":{}}}}}`, jp) + require.False(t, ok) + }) + + t.Run("asterisk, map", func(t *testing.T) { + jp := "$" + strings.Repeat(construct(10, `'a'`), 3) + ".*" + _, ok := unmarshalGet(t, `{"a":{"a":{"a":{"x":{},"y":{}}}}}`, jp) + require.False(t, ok) + }) + }) } // These tests are taken directly from C# code.