diff --git a/go.mod b/go.mod index ec2918e..cac5089 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( git.frostfs.info/TrueCloudLab/frostfs-contract v0.18.1-0.20231129062201-a1b61d394958 + github.com/google/uuid v1.3.0 github.com/nspcc-dev/neo-go v0.103.0 github.com/stretchr/testify v1.8.4 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 @@ -12,7 +13,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/hashicorp/golang-lru v0.6.0 // indirect github.com/mr-tron/base58 v1.2.0 // indirect github.com/nspcc-dev/go-ordered-json v0.0.0-20220111165707-25110be27d22 // indirect diff --git a/pkg/chain/chain.go b/pkg/chain/chain.go index ec718a5..d72611f 100644 --- a/pkg/chain/chain.go +++ b/pkg/chain/chain.go @@ -45,7 +45,7 @@ func (id *ID) UnmarshalJSON(data []byte) error { } func (c *Chain) Bytes() []byte { - data, err := json.Marshal(c) + data, err := c.MarshalBinary() if err != nil { panic(err) } @@ -53,7 +53,7 @@ func (c *Chain) Bytes() []byte { } func (c *Chain) DecodeBytes(b []byte) error { - return json.Unmarshal(b, c) + return c.UnmarshalBinary(b) } type Rule struct { diff --git a/pkg/chain/marshal.go b/pkg/chain/marshal.go new file mode 100644 index 0000000..343eca4 --- /dev/null +++ b/pkg/chain/marshal.go @@ -0,0 +1,257 @@ +package chain + +import ( + "encoding" + "fmt" + + "git.frostfs.info/TrueCloudLab/policy-engine/pkg/marshal" +) + +const ( + ChainMarshalVersion uint8 = 0 // increase if breaking change +) + +var ( + _ encoding.BinaryMarshaler = (*Chain)(nil) + _ encoding.BinaryUnmarshaler = (*Chain)(nil) +) + +func (c *Chain) MarshalBinary() ([]byte, error) { + s := marshal.UInt8Size // Marshaller version + s += marshal.UInt8Size // Chain version + s += marshal.StringSize(string(c.ID)) + s += marshal.SliceSize(c.Rules, ruleSize) + s += marshal.UInt8Size // MatchType + + buf := make([]byte, s) + var offset int + var err error + offset, err = marshal.UInt8Marshal(buf, offset, marshal.Version) + if err != nil { + return nil, err + } + offset, err = marshal.UInt8Marshal(buf, offset, ChainMarshalVersion) + if err != nil { + return nil, err + } + offset, err = marshal.StringMarshal(buf, offset, string(c.ID)) + if err != nil { + return nil, err + } + offset, err = marshal.SliceMarshal(buf, offset, c.Rules, marshalRule) + if err != nil { + return nil, err + } + offset, err = marshal.UInt8Marshal(buf, offset, uint8(c.MatchType)) + if err != nil { + return nil, err + } + + if err := marshal.VerifyMarshal(buf, offset); err != nil { + return nil, err + } + return buf, nil +} + +func (c *Chain) UnmarshalBinary(data []byte) error { + var offset int + + marshallerVersion, offset, err := marshal.UInt8Unmarshal(data, offset) + if err != nil { + return err + } + if marshallerVersion != marshal.Version { + return fmt.Errorf("unsupported marshaller version %d", marshallerVersion) + } + + chainVersion, offset, err := marshal.UInt8Unmarshal(data, offset) + if err != nil { + return err + } + if chainVersion != ChainMarshalVersion { + return fmt.Errorf("unsupported chain version %d", chainVersion) + } + + idStr, offset, err := marshal.StringUnmarshal(data, offset) + if err != nil { + return err + } + c.ID = ID(idStr) + + c.Rules, offset, err = marshal.SliceUnmarshal(data, offset, unmarshalRule) + if err != nil { + return err + } + + matchTypeV, offset, err := marshal.UInt8Unmarshal(data, offset) + if err != nil { + return err + } + c.MatchType = MatchType(matchTypeV) + + return marshal.VerifyUnmarshal(data, offset) +} + +func ruleSize(r Rule) int { + s := marshal.ByteSize // Status + s += actionsSize(r.Actions) + s += resourcesSize(r.Resources) + s += marshal.BoolSize // Any + s += marshal.SliceSize(r.Condition, conditionSize) + return s +} + +func marshalRule(buf []byte, offset int, r Rule) (int, error) { + offset, err := marshal.ByteMarshal(buf, offset, byte(r.Status)) + if err != nil { + return 0, err + } + offset, err = marshalActions(buf, offset, r.Actions) + if err != nil { + return 0, err + } + offset, err = marshalResources(buf, offset, r.Resources) + if err != nil { + return 0, err + } + offset, err = marshal.BoolMarshal(buf, offset, r.Any) + if err != nil { + return 0, err + } + return marshal.SliceMarshal(buf, offset, r.Condition, marshalCondition) +} + +func unmarshalRule(buf []byte, offset int) (Rule, int, error) { + var r Rule + statusV, offset, err := marshal.ByteUnmarshal(buf, offset) + if err != nil { + return Rule{}, 0, err + } + r.Status = Status(statusV) + + r.Actions, offset, err = unmarshalActions(buf, offset) + if err != nil { + return Rule{}, 0, err + } + + r.Resources, offset, err = unmarshalResources(buf, offset) + if err != nil { + return Rule{}, 0, err + } + + r.Any, offset, err = marshal.BoolUnmarshal(buf, offset) + if err != nil { + return Rule{}, 0, err + } + + r.Condition, offset, err = marshal.SliceUnmarshal(buf, offset, unmarshalCondition) + if err != nil { + return Rule{}, 0, err + } + + return r, offset, nil +} + +func actionsSize(a Actions) int { + return marshal.BoolSize + // Inverted + marshal.SliceSize(a.Names, marshal.StringSize) +} + +func marshalActions(buf []byte, offset int, a Actions) (int, error) { + offset, err := marshal.BoolMarshal(buf, offset, a.Inverted) + if err != nil { + return 0, err + } + return marshal.SliceMarshal(buf, offset, a.Names, marshal.StringMarshal) +} + +func unmarshalActions(buf []byte, offset int) (Actions, int, error) { + var a Actions + var err error + a.Inverted, offset, err = marshal.BoolUnmarshal(buf, offset) + if err != nil { + return Actions{}, 0, err + } + a.Names, offset, err = marshal.SliceUnmarshal(buf, offset, marshal.StringUnmarshal) + if err != nil { + return Actions{}, 0, err + } + return a, offset, nil +} + +func resourcesSize(r Resources) int { + return marshal.BoolSize + // Inverted + marshal.SliceSize(r.Names, marshal.StringSize) +} + +func marshalResources(buf []byte, offset int, r Resources) (int, error) { + offset, err := marshal.BoolMarshal(buf, offset, r.Inverted) + if err != nil { + return 0, err + } + return marshal.SliceMarshal(buf, offset, r.Names, marshal.StringMarshal) +} + +func unmarshalResources(buf []byte, offset int) (Resources, int, error) { + var r Resources + var err error + r.Inverted, offset, err = marshal.BoolUnmarshal(buf, offset) + if err != nil { + return Resources{}, 0, err + } + r.Names, offset, err = marshal.SliceUnmarshal(buf, offset, marshal.StringUnmarshal) + if err != nil { + return Resources{}, 0, err + } + return r, offset, nil +} + +func conditionSize(c Condition) int { + return marshal.ByteSize + // Op + marshal.ByteSize + // Object + marshal.StringSize(c.Key) + + marshal.StringSize(c.Value) +} + +func marshalCondition(buf []byte, offset int, c Condition) (int, error) { + offset, err := marshal.ByteMarshal(buf, offset, byte(c.Op)) + if err != nil { + return 0, err + } + offset, err = marshal.ByteMarshal(buf, offset, byte(c.Object)) + if err != nil { + return 0, err + } + offset, err = marshal.StringMarshal(buf, offset, c.Key) + if err != nil { + return 0, err + } + return marshal.StringMarshal(buf, offset, c.Value) +} + +func unmarshalCondition(buf []byte, offset int) (Condition, int, error) { + var c Condition + opV, offset, err := marshal.ByteUnmarshal(buf, offset) + if err != nil { + return Condition{}, 0, err + } + c.Op = ConditionType(opV) + + obV, offset, err := marshal.ByteUnmarshal(buf, offset) + if err != nil { + return Condition{}, 0, err + } + c.Object = ObjectType(obV) + + c.Key, offset, err = marshal.StringUnmarshal(buf, offset) + if err != nil { + return Condition{}, 0, err + } + + c.Value, offset, err = marshal.StringUnmarshal(buf, offset) + if err != nil { + return Condition{}, 0, err + } + + return c, offset, nil +} diff --git a/pkg/chain/marshal_test.go b/pkg/chain/marshal_test.go new file mode 100644 index 0000000..bc6045f --- /dev/null +++ b/pkg/chain/marshal_test.go @@ -0,0 +1,272 @@ +package chain + +import ( + "fmt" + "testing" + + "git.frostfs.info/TrueCloudLab/policy-engine/schema/native" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestChainMarshalling(t *testing.T) { + t.Parallel() + for _, id := range generateTestIDs() { + for _, rules := range generateTestRules() { + for _, matchType := range generateTestMatchTypes() { + performMarshalTest(t, id, rules, matchType) + } + } + } +} + +func TestInvalidChainData(t *testing.T) { + var ch Chain + require.Error(t, ch.UnmarshalBinary(nil)) + require.Error(t, ch.UnmarshalBinary([]byte{})) + require.Error(t, ch.UnmarshalBinary([]byte{1, 2, 3})) + require.Error(t, ch.UnmarshalBinary([]byte("\x00\x00:aws:iam::namespace:group/so\x82\x82\x82\x82\x82\x82u\x82"))) +} + +func FuzzUnmarshal(f *testing.F) { + for _, id := range generateTestIDs() { + for _, rules := range generateTestRules() { + for _, matchType := range generateTestMatchTypes() { + + chain := Chain{ + ID: id, + Rules: rules, + MatchType: matchType, + } + data, err := chain.MarshalBinary() + require.NoError(f, err) + f.Add(data) + } + } + } + + f.Fuzz(func(t *testing.T, data []byte) { + var ch Chain + require.NotPanics(t, func() { + _ = ch.UnmarshalBinary(data) + }) + }) +} + +func performMarshalTest(t *testing.T, id ID, r []Rule, mt MatchType) { + chain := Chain{ + ID: id, + Rules: r, + MatchType: mt, + } + data, err := chain.MarshalBinary() + require.NoError(t, err) + + var unmarshalledChain Chain + require.NoError(t, unmarshalledChain.UnmarshalBinary(data)) + + require.Equal(t, chain, unmarshalledChain) +} + +func generateTestIDs() []ID { + return []ID{ + ID(""), + ID(uuid.New().String()), + ID("*::/"), + ID("avada kedavra"), + ID("arn:aws:iam::namespace:group/some_group"), + ID("$Object:homomorphicHash"), + ID("native:container/ns/9LPLUFZpEmfidG4n44vi2cjXKXSqWT492tCvLJiJ8W1J"), + } +} + +func generateTestRules() [][]Rule { + result := [][]Rule{ + nil, + {}, + {}, + } + + for _, st := range generateTestStatuses() { + for _, act := range generateTestActions() { + for _, res := range generateTestResources() { + for _, cond := range generateTestConditions() { + result[2] = append(result[2], Rule{ + Status: st, + Actions: act, + Resources: res, + Condition: cond, + Any: true, + }) + result[2] = append(result[2], Rule{ + Status: st, + Actions: act, + Resources: res, + Condition: cond, + }) + } + } + } + } + + return result +} + +func generateTestStatuses() []Status { + return []Status{ + Allow, + NoRuleFound, + AccessDenied, + QuotaLimitReached, + } +} + +func generateTestActions() []Actions { + return []Actions{ + { + Inverted: true, + Names: nil, + }, + { + Names: nil, + }, + { + Inverted: true, + Names: []string{}, + }, + { + Names: []string{}, + }, + { + Inverted: true, + Names: []string{native.MethodPutObject}, + }, + { + Names: []string{native.MethodPutObject}, + }, + { + Inverted: true, + Names: []string{native.MethodPutObject, native.MethodDeleteContainer, native.MethodDeleteObject}, + }, + { + Names: []string{native.MethodPutObject, native.MethodDeleteContainer, native.MethodDeleteObject}, + }, + } +} + +func generateTestResources() []Resources { + return []Resources{ + { + Inverted: true, + Names: nil, + }, + { + Names: nil, + }, + { + Inverted: true, + Names: []string{}, + }, + { + Names: []string{}, + }, + { + Inverted: true, + Names: []string{native.ResourceFormatAllObjects}, + }, + { + Names: []string{native.ResourceFormatAllObjects}, + }, + { + Inverted: true, + Names: []string{ + native.ResourceFormatAllObjects, + fmt.Sprintf(native.ResourceFormatRootContainer, "9LPLUFZpEmfidG4n44vi2cjXKXSqWT492tCvLJiJ8W1J"), + }, + }, + { + Names: []string{ + native.ResourceFormatAllObjects, + fmt.Sprintf(native.ResourceFormatRootContainer, "9LPLUFZpEmfidG4n44vi2cjXKXSqWT492tCvLJiJ8W1J"), + }, + }, + } +} + +func generateTestConditions() [][]Condition { + result := [][]Condition{ + nil, + {}, + {}, + } + + for _, ct := range generateTestConditionTypes() { + for _, ot := range generateObjectTypes() { + result[2] = append(result[2], Condition{ + Op: ct, + Object: ot, + Key: "", + Value: "", + }) + + result[2] = append(result[2], Condition{ + Op: ct, + Object: ot, + Key: "key", + Value: "", + }) + + result[2] = append(result[2], Condition{ + Op: ct, + Object: ot, + Key: "", + Value: "value", + }) + + result[2] = append(result[2], Condition{ + Op: ct, + Object: ot, + Key: "key", + Value: "value", + }) + } + } + + return result +} + +func generateTestConditionTypes() []ConditionType { + return []ConditionType{ + CondStringEquals, + CondStringNotEquals, + CondStringEqualsIgnoreCase, + CondStringNotEqualsIgnoreCase, + CondStringLike, + CondStringNotLike, + CondStringLessThan, + CondStringLessThanEquals, + CondStringGreaterThan, + CondStringGreaterThanEquals, + CondNumericEquals, + CondNumericNotEquals, + CondNumericLessThan, + CondNumericLessThanEquals, + CondNumericGreaterThan, + CondNumericGreaterThanEquals, + CondSliceContains, + } +} + +func generateObjectTypes() []ObjectType { + return []ObjectType{ + ObjectResource, + ObjectRequest, + } +} + +func generateTestMatchTypes() []MatchType { + return []MatchType{ + MatchTypeDenyPriority, + MatchTypeFirstMatch, + } +} diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go new file mode 100644 index 0000000..e296de6 --- /dev/null +++ b/pkg/marshal/marshal.go @@ -0,0 +1,267 @@ +package marshal + +import ( + "encoding/binary" + "fmt" +) + +const ( + Version byte = 0 // increase if breaking change + + ByteSize int = 1 + UInt8Size int = ByteSize + BoolSize int = ByteSize + + nilSlice int64 = -1 + nilSliceSize int = 1 + + byteTrue uint8 = 1 + byteFalse uint8 = 0 + + // maxSliceLen taken from https://github.com/neo-project/neo/blob/38218bbee5bbe8b33cd8f9453465a19381c9a547/src/Neo/IO/Helper.cs#L77 + maxSliceLen = 0x1000000 +) + +type MarshallerError struct { + errMsg string + offset int +} + +func (e *MarshallerError) Error() string { + if e == nil { + return "" + } + if e.offset < 0 { + return e.errMsg + } + return fmt.Sprintf("%s (offset: %d)", e.errMsg, e.offset) +} + +func errBufTooSmall(t string, marshal bool, offset int) error { + action := "unmarshal" + if marshal { + action = "marshal" + } + return &MarshallerError{ + errMsg: fmt.Sprintf("not enough bytes left to %s value of type '%s'", action, t), + offset: offset, + } +} + +func VerifyMarshal(buf []byte, lastOffset int) error { + if len(buf) != lastOffset { + return &MarshallerError{ + errMsg: "actual data size differs from expected", + offset: -1, + } + } + return nil +} + +func VerifyUnmarshal(buf []byte, lastOffset int) error { + if len(buf) != lastOffset { + return &MarshallerError{ + errMsg: "unmarshalled bytes left", + } + } + return nil +} + +func SliceSize[T any](slice []T, sizeOf func(T) int) int { + if slice == nil { + return nilSliceSize + } + s := Int64Size(int64(len(slice))) + for _, v := range slice { + s += sizeOf(v) + } + return s +} + +func SliceMarshal[T any](buf []byte, offset int, slice []T, marshalT func([]byte, int, T) (int, error)) (int, error) { + if slice == nil { + return Int64Marshal(buf, offset, nilSlice) + } + if len(slice) > maxSliceLen { + return 0, &MarshallerError{ + errMsg: fmt.Sprintf("slice size if too big: '%d'", len(slice)), + offset: offset, + } + } + offset, err := Int64Marshal(buf, offset, int64(len(slice))) + if err != nil { + return 0, err + } + for _, v := range slice { + offset, err = marshalT(buf, offset, v) + if err != nil { + return 0, err + } + } + return offset, nil +} + +func SliceUnmarshal[T any](buf []byte, offset int, unmarshalT func(buf []byte, offset int) (T, int, error)) ([]T, int, error) { + size, offset, err := Int64Unmarshal(buf, offset) + if err != nil { + return nil, 0, err + } + if size == nilSlice { + return nil, offset, nil + } + if size > maxSliceLen { + return nil, 0, &MarshallerError{ + errMsg: fmt.Sprintf("slice size if too big: '%d'", size), + offset: offset, + } + } + if size < 0 { + return nil, 0, &MarshallerError{ + errMsg: fmt.Sprintf("invalid slice size: '%d'", size), + offset: offset, + } + } + result := make([]T, size) + for idx := 0; idx < len(result); idx++ { + result[idx], offset, err = unmarshalT(buf, offset) + if err != nil { + return nil, 0, err + } + } + return result, offset, nil +} + +func Int64Size(v int64) int { + // https://cs.opensource.google/go/go/+/master:src/encoding/binary/varint.go;l=92;drc=dac9b9ddbd5160c5f4552410f5f8281bd5eed38c + // and + // https://cs.opensource.google/go/go/+/master:src/encoding/binary/varint.go;l=41;drc=dac9b9ddbd5160c5f4552410f5f8281bd5eed38c + ux := uint64(v) << 1 + if v < 0 { + ux = ^ux + } + s := 0 + for ux >= 0x80 { + s++ + ux >>= 7 + } + return s + 1 +} + +func Int64Marshal(buf []byte, offset int, v int64) (int, error) { + if len(buf)-offset < Int64Size(v) { + return 0, errBufTooSmall("int64", true, offset) + } + return offset + binary.PutVarint(buf[offset:], v), nil +} + +func Int64Unmarshal(buf []byte, offset int) (int64, int, error) { + v, read := binary.Varint(buf[offset:]) + if read == 0 { + return 0, 0, errBufTooSmall("int64", false, offset) + } + if read < 0 { + return 0, 0, &MarshallerError{ + errMsg: "int64 unmarshal overflow", + offset: offset, + } + } + return v, offset + read, nil +} + +func StringSize(s string) int { + return Int64Size(int64(len(s))) + len(s) +} + +func StringMarshal(buf []byte, offset int, s string) (int, error) { + if len(s) > maxSliceLen { + return 0, &MarshallerError{ + errMsg: fmt.Sprintf("string is too long: '%d'", len(s)), + offset: offset, + } + } + if len(buf)-offset < Int64Size(int64(len(s)))+len(s) { + return 0, errBufTooSmall("string", true, offset) + } + + offset, err := Int64Marshal(buf, offset, int64(len(s))) + if err != nil { + return 0, err + } + if s == "" { + return offset, nil + } + return offset + copy(buf[offset:], s), nil +} + +func StringUnmarshal(buf []byte, offset int) (string, int, error) { + size, offset, err := Int64Unmarshal(buf, offset) + if err != nil { + return "", 0, err + } + if size == 0 { + return "", offset, nil + } + if size > maxSliceLen { + return "", 0, &MarshallerError{ + errMsg: fmt.Sprintf("string is too long: '%d'", size), + offset: offset, + } + } + if size < 0 { + return "", 0, &MarshallerError{ + errMsg: fmt.Sprintf("invalid string size: '%d'", size), + offset: offset, + } + } + if len(buf)-offset < int(size) { + return "", 0, errBufTooSmall("string", false, offset) + } + return string(buf[offset : offset+int(size)]), offset + int(size), nil +} + +func UInt8Marshal(buf []byte, offset int, value uint8) (int, error) { + if len(buf)-offset < 1 { + return 0, errBufTooSmall("uint8", true, offset) + } + buf[offset] = value + return offset + 1, nil +} + +func UInt8Unmarshal(buf []byte, offset int) (uint8, int, error) { + if len(buf)-offset < 1 { + return 0, 0, errBufTooSmall("uint8", false, offset) + } + return buf[offset], offset + 1, nil +} + +func ByteMarshal(buf []byte, offset int, value byte) (int, error) { + return UInt8Marshal(buf, offset, value) +} + +func ByteUnmarshal(buf []byte, offset int) (byte, int, error) { + return UInt8Unmarshal(buf, offset) +} + +func BoolMarshal(buf []byte, offset int, value bool) (int, error) { + if value { + return UInt8Marshal(buf, offset, byteTrue) + } + return UInt8Marshal(buf, offset, byteFalse) +} + +func BoolUnmarshal(buf []byte, offset int) (bool, int, error) { + v, offset, err := UInt8Unmarshal(buf, offset) + if err != nil { + return false, 0, err + } + if v == byteTrue { + return true, offset, nil + } + if v == byteFalse { + return false, offset, nil + } + return false, 0, &MarshallerError{ + errMsg: fmt.Sprintf("invalid marshalled value for bool: %d", v), + offset: offset - BoolSize, + } +} diff --git a/pkg/marshal/marshal_test.go b/pkg/marshal/marshal_test.go new file mode 100644 index 0000000..5d3babd --- /dev/null +++ b/pkg/marshal/marshal_test.go @@ -0,0 +1,313 @@ +package marshal + +import ( + "encoding/binary" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMarshalling(t *testing.T) { + t.Parallel() + t.Run("slice", func(t *testing.T) { + t.Parallel() + t.Run("nil slice", func(t *testing.T) { + t.Parallel() + + var int64s []int64 + expectedSize := SliceSize(int64s, Int64Size) + require.Equal(t, 1, expectedSize) + buf := make([]byte, expectedSize) + offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + result, offset, err := SliceUnmarshal(buf, 0, Int64Unmarshal) + require.NoError(t, err) + require.NoError(t, VerifyUnmarshal(buf, offset)) + require.Nil(t, result) + }) + + t.Run("empty slice", func(t *testing.T) { + t.Parallel() + + int64s := make([]int64, 0) + expectedSize := SliceSize(int64s, Int64Size) + require.Equal(t, 1, expectedSize) + buf := make([]byte, expectedSize) + offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + result, offset, err := SliceUnmarshal(buf, 0, Int64Unmarshal) + require.NoError(t, err) + require.NoError(t, VerifyUnmarshal(buf, offset)) + require.NotNil(t, result) + require.Len(t, result, 0) + }) + + t.Run("non empty slice", func(t *testing.T) { + t.Parallel() + + int64s := make([]int64, 100) + for i := range int64s { + int64s[i] = int64(i) + } + expectedSize := SliceSize(int64s, Int64Size) + buf := make([]byte, expectedSize) + offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + result, offset, err := SliceUnmarshal(buf, 0, Int64Unmarshal) + require.NoError(t, err) + require.NoError(t, VerifyUnmarshal(buf, offset)) + require.Equal(t, int64s, result) + }) + + t.Run("corrupted slice size", func(t *testing.T) { + t.Parallel() + + int64s := make([]int64, 100) + for i := range int64s { + int64s[i] = int64(i) + } + expectedSize := SliceSize(int64s, Int64Size) + buf := make([]byte, expectedSize) + offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + for i := 0; i < binary.MaxVarintLen64; i++ { + buf[i] = 129 + } + + _, _, err = SliceUnmarshal(buf, 0, Int64Unmarshal) + var mErr *MarshallerError + require.ErrorAs(t, err, &mErr) + + for i := 0; i < binary.MaxVarintLen64; i++ { + buf[i] = 127 + } + _, _, err = SliceUnmarshal(buf, 0, Int64Unmarshal) + require.ErrorAs(t, err, &mErr) + }) + + t.Run("corrupted slice item", func(t *testing.T) { + t.Parallel() + + int64s := make([]int64, 100) + for i := range int64s { + int64s[i] = int64(i) + } + expectedSize := SliceSize(int64s, Int64Size) + buf := make([]byte, expectedSize) + offset, err := SliceMarshal(buf, 0, int64s, Int64Marshal) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + for i := 2; i < binary.MaxVarintLen64+2; i++ { + buf[i] = 129 + } + + _, _, err = SliceUnmarshal(buf, 0, Int64Unmarshal) + var mErr *MarshallerError + require.ErrorAs(t, err, &mErr) + }) + + t.Run("small buffer", func(t *testing.T) { + t.Parallel() + + int64s := make([]int64, 100) + for i := range int64s { + int64s[i] = int64(i) + } + buf := make([]byte, 1) + _, err := SliceMarshal(buf, 0, int64s, Int64Marshal) + var mErr *MarshallerError + require.ErrorAs(t, err, &mErr) + + buf = make([]byte, 10) + _, err = SliceMarshal(buf, 0, int64s, Int64Marshal) + require.ErrorAs(t, err, &mErr) + }) + }) + + t.Run("int64", func(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + require.Equal(t, 1, Int64Size(0)) + require.Equal(t, binary.MaxVarintLen64, Int64Size(math.MaxInt64)) + require.Equal(t, binary.MaxVarintLen64, Int64Size(math.MinInt64)) + + for _, v := range []int64{0, math.MinInt64, math.MaxInt64} { + size := Int64Size(v) + buf := make([]byte, size) + offset, err := Int64Marshal(buf, 0, v) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + uv, offset, err := Int64Unmarshal(buf, 0) + require.NoError(t, err) + require.NoError(t, VerifyUnmarshal(buf, offset)) + require.Equal(t, v, uv) + } + }) + + t.Run("invalid buffer", func(t *testing.T) { + t.Parallel() + + var mErr *MarshallerError + + _, err := Int64Marshal([]byte{}, 0, 100500) + require.ErrorAs(t, err, &mErr) + + _, _, err = Int64Unmarshal(nil, 0) + require.ErrorAs(t, err, &mErr) + }) + + t.Run("overflow", func(t *testing.T) { + t.Parallel() + + var mErr *MarshallerError + + var v int64 = math.MaxInt64 + buf := make([]byte, Int64Size(v)) + _, err := Int64Marshal(buf, 0, v) + require.NoError(t, err) + + buf[9] = 2 + + _, _, err = Int64Unmarshal(buf, 0) + require.ErrorAs(t, err, &mErr) + }) + }) + + t.Run("string", func(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + for _, v := range []string{ + "", "arn:aws:iam::namespace:group/some_group", "$Object:homomorphicHash", + "native:container/ns/9LPLUFZpEmfidG4n44vi2cjXKXSqWT492tCvLJiJ8W1J", + } { + size := StringSize(v) + buf := make([]byte, size) + offset, err := StringMarshal(buf, 0, v) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + uv, offset, err := StringUnmarshal(buf, 0) + require.NoError(t, err) + require.NoError(t, VerifyUnmarshal(buf, offset)) + require.Equal(t, v, uv) + } + }) + + t.Run("invalid buffer", func(t *testing.T) { + t.Parallel() + + str := "avada kedavra" + + var mErr *MarshallerError + _, err := StringMarshal(nil, 0, str) + require.ErrorAs(t, err, &mErr) + + _, _, err = StringUnmarshal(nil, 0) + require.ErrorAs(t, err, &mErr) + + buf := make([]byte, StringSize(str)) + offset, err := StringMarshal(buf, 0, str) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + buf = buf[:len(buf)-1] + _, _, err = StringUnmarshal(buf, 0) + require.ErrorAs(t, err, &mErr) + }) + }) + + t.Run("uint8, byte", func(t *testing.T) { + t.Parallel() + + for _, v := range []byte{0, 8, 16, 32, 64, 128, 255} { + buf := make([]byte, ByteSize) + offset, err := ByteMarshal(buf, 0, v) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + ub, offset, err := ByteUnmarshal(buf, 0) + require.NoError(t, err) + require.NoError(t, VerifyUnmarshal(buf, offset)) + require.Equal(t, v, ub) + + buf = make([]byte, UInt8Size) + offset, err = UInt8Marshal(buf, 0, v) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + uu, offset, err := UInt8Unmarshal(buf, 0) + require.NoError(t, err) + require.NoError(t, VerifyUnmarshal(buf, offset)) + require.Equal(t, v, uu) + } + }) + + t.Run("bool", func(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + for _, v := range []bool{false, true} { + buf := make([]byte, BoolSize) + offset, err := BoolMarshal(buf, 0, v) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + ub, offset, err := BoolUnmarshal(buf, 0) + require.NoError(t, err) + require.NoError(t, VerifyUnmarshal(buf, offset)) + require.Equal(t, v, ub) + } + }) + + t.Run("invalid value", func(t *testing.T) { + t.Parallel() + buf := make([]byte, BoolSize) + offset, err := BoolMarshal(buf, 0, true) + require.NoError(t, err) + require.NoError(t, VerifyMarshal(buf, offset)) + + buf[0] = 2 + + _, _, err = BoolUnmarshal(buf, 0) + var mErr *MarshallerError + require.ErrorAs(t, err, &mErr) + }) + + t.Run("invalid buffer", func(t *testing.T) { + t.Parallel() + var mErr *MarshallerError + + _, err := BoolMarshal(nil, 0, true) + require.ErrorAs(t, err, &mErr) + + buf := append(make([]byte, BoolSize), 100) + offset, err := BoolMarshal(buf, 0, true) + require.NoError(t, err) + require.ErrorAs(t, VerifyMarshal(buf, offset), &mErr) + + v, offset, err := BoolUnmarshal(buf, 0) + require.NoError(t, err) + require.True(t, v) + require.ErrorAs(t, VerifyUnmarshal(buf, offset), &mErr) + + _, _, err = BoolUnmarshal(nil, 0) + require.ErrorAs(t, err, &mErr) + }) + }) +}