285 lines
7.9 KiB
Go
285 lines
7.9 KiB
Go
|
package erasurecode_test
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"math"
|
||
|
"testing"
|
||
|
|
||
|
cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
|
||
|
objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
|
||
|
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/erasurecode"
|
||
|
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/transformer"
|
||
|
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
|
||
|
"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/version"
|
||
|
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func TestErasureCodeReconstruct(t *testing.T) {
|
||
|
const payloadSize = 99
|
||
|
const dataCount = 3
|
||
|
const parityCount = 2
|
||
|
|
||
|
// We would also like to test padding behaviour,
|
||
|
// so ensure padding is done.
|
||
|
require.NotZero(t, payloadSize%(dataCount+parityCount))
|
||
|
|
||
|
pk, err := keys.NewPrivateKey()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
original := newObject(t, payloadSize, pk)
|
||
|
|
||
|
c, err := erasurecode.NewConstructor(dataCount, parityCount)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
parts, err := c.Split(original, &pk.PrivateKey)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
t.Run("reconstruct header", func(t *testing.T) {
|
||
|
original := original.CutPayload()
|
||
|
parts := cloneSlice(parts)
|
||
|
for i := range parts {
|
||
|
parts[i] = parts[i].CutPayload()
|
||
|
}
|
||
|
t.Run("from data", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
for i := dataCount; i < dataCount+parityCount; i++ {
|
||
|
parts[i] = nil
|
||
|
}
|
||
|
reconstructed, err := c.ReconstructHeader(parts)
|
||
|
require.NoError(t, err)
|
||
|
verifyReconstruction(t, original, reconstructed)
|
||
|
})
|
||
|
t.Run("from parity", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
for i := 0; i < parityCount; i++ {
|
||
|
parts[i] = nil
|
||
|
}
|
||
|
reconstructed, err := c.ReconstructHeader(parts)
|
||
|
require.NoError(t, err)
|
||
|
verifyReconstruction(t, original, reconstructed)
|
||
|
|
||
|
t.Run("not enough shards", func(t *testing.T) {
|
||
|
parts[parityCount] = nil
|
||
|
_, err := c.ReconstructHeader(parts)
|
||
|
require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
|
||
|
})
|
||
|
})
|
||
|
t.Run("only nil parts", func(t *testing.T) {
|
||
|
parts := make([]*objectSDK.Object, len(parts))
|
||
|
_, err := c.ReconstructHeader(parts)
|
||
|
require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
|
||
|
})
|
||
|
t.Run("missing EC header", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
parts[0] = deepCopy(t, parts[0])
|
||
|
parts[0].SetECHeader(nil)
|
||
|
|
||
|
_, err := c.ReconstructHeader(parts)
|
||
|
require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
|
||
|
})
|
||
|
t.Run("invalid index", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
parts[0] = deepCopy(t, parts[0])
|
||
|
|
||
|
ec := parts[0].GetECHeader()
|
||
|
ec.SetIndex(1)
|
||
|
parts[0].SetECHeader(ec)
|
||
|
|
||
|
_, err := c.ReconstructHeader(parts)
|
||
|
require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
|
||
|
})
|
||
|
t.Run("invalid total", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
parts[0] = deepCopy(t, parts[0])
|
||
|
|
||
|
ec := parts[0].GetECHeader()
|
||
|
ec.SetTotal(uint32(len(parts) + 1))
|
||
|
parts[0].SetECHeader(ec)
|
||
|
|
||
|
_, err := c.ReconstructHeader(parts)
|
||
|
require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
|
||
|
})
|
||
|
t.Run("inconsistent header length", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
parts[0] = deepCopy(t, parts[0])
|
||
|
|
||
|
ec := parts[0].GetECHeader()
|
||
|
ec.SetHeaderLength(ec.HeaderLength() - 1)
|
||
|
parts[0].SetECHeader(ec)
|
||
|
|
||
|
_, err := c.ReconstructHeader(parts)
|
||
|
require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
|
||
|
})
|
||
|
t.Run("invalid header length", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
for i := range parts {
|
||
|
parts[i] = deepCopy(t, parts[i])
|
||
|
|
||
|
ec := parts[0].GetECHeader()
|
||
|
ec.SetHeaderLength(math.MaxUint32)
|
||
|
parts[0].SetECHeader(ec)
|
||
|
}
|
||
|
|
||
|
_, err := c.ReconstructHeader(parts)
|
||
|
require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
|
||
|
})
|
||
|
})
|
||
|
t.Run("reconstruct data", func(t *testing.T) {
|
||
|
t.Run("from data", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
for i := dataCount; i < dataCount+parityCount; i++ {
|
||
|
parts[i] = nil
|
||
|
}
|
||
|
reconstructed, err := c.Reconstruct(parts)
|
||
|
require.NoError(t, err)
|
||
|
verifyReconstruction(t, original, reconstructed)
|
||
|
})
|
||
|
t.Run("from parity", func(t *testing.T) {
|
||
|
parts := cloneSlice(parts)
|
||
|
for i := 0; i < parityCount; i++ {
|
||
|
parts[i] = nil
|
||
|
}
|
||
|
reconstructed, err := c.Reconstruct(parts)
|
||
|
require.NoError(t, err)
|
||
|
verifyReconstruction(t, original, reconstructed)
|
||
|
|
||
|
t.Run("not enough shards", func(t *testing.T) {
|
||
|
parts[parityCount] = nil
|
||
|
_, err := c.Reconstruct(parts)
|
||
|
require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
|
||
|
})
|
||
|
})
|
||
|
})
|
||
|
t.Run("reconstruct parts", func(t *testing.T) {
|
||
|
// We would like to also test that ReconstructParts doesn't perform
|
||
|
// excessive work, so ensure this test makes sense.
|
||
|
require.GreaterOrEqual(t, parityCount, 2)
|
||
|
|
||
|
t.Run("from data", func(t *testing.T) {
|
||
|
oldParts := parts
|
||
|
parts := cloneSlice(parts)
|
||
|
for i := dataCount; i < dataCount+parityCount; i++ {
|
||
|
parts[i] = nil
|
||
|
}
|
||
|
|
||
|
required := make([]bool, len(parts))
|
||
|
required[dataCount] = true
|
||
|
|
||
|
require.NoError(t, c.ReconstructParts(parts, required, nil))
|
||
|
|
||
|
old := deepCopy(t, oldParts[dataCount])
|
||
|
old.SetSignature(nil)
|
||
|
require.Equal(t, old, parts[dataCount])
|
||
|
|
||
|
for i := dataCount + 1; i < dataCount+parityCount; i++ {
|
||
|
require.Nil(t, parts[i])
|
||
|
}
|
||
|
})
|
||
|
t.Run("from parity", func(t *testing.T) {
|
||
|
oldParts := parts
|
||
|
parts := cloneSlice(parts)
|
||
|
for i := 0; i < parityCount; i++ {
|
||
|
parts[i] = nil
|
||
|
}
|
||
|
|
||
|
required := make([]bool, len(parts))
|
||
|
required[0] = true
|
||
|
|
||
|
require.NoError(t, c.ReconstructParts(parts, required, nil))
|
||
|
|
||
|
old := deepCopy(t, oldParts[0])
|
||
|
old.SetSignature(nil)
|
||
|
require.Equal(t, old, parts[0])
|
||
|
|
||
|
for i := 1; i < parityCount; i++ {
|
||
|
require.Nil(t, parts[i])
|
||
|
}
|
||
|
})
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func newObject(t *testing.T, size uint64, pk *keys.PrivateKey) *objectSDK.Object {
|
||
|
// Use transformer to form object to avoid potential bugs with yet another helper object creation in tests.
|
||
|
tt := &testTarget{}
|
||
|
p := transformer.NewPayloadSizeLimiter(transformer.Params{
|
||
|
Key: &pk.PrivateKey,
|
||
|
NextTargetInit: func() transformer.ObjectWriter { return tt },
|
||
|
NetworkState: dummyEpochSource(123),
|
||
|
MaxSize: size + 1,
|
||
|
WithoutHomomorphicHash: true,
|
||
|
})
|
||
|
cnr := cidtest.ID()
|
||
|
ver := version.Current()
|
||
|
hdr := objectSDK.New()
|
||
|
hdr.SetContainerID(cnr)
|
||
|
hdr.SetType(objectSDK.TypeRegular)
|
||
|
hdr.SetVersion(&ver)
|
||
|
|
||
|
var owner user.ID
|
||
|
user.IDFromKey(&owner, pk.PrivateKey.PublicKey)
|
||
|
hdr.SetOwnerID(owner)
|
||
|
|
||
|
var attr objectSDK.Attribute
|
||
|
attr.SetKey("somekey")
|
||
|
attr.SetValue("somevalue")
|
||
|
hdr.SetAttributes(attr)
|
||
|
|
||
|
expectedPayload := make([]byte, size)
|
||
|
_, _ = rand.Read(expectedPayload)
|
||
|
writeObject(t, context.Background(), p, hdr, expectedPayload)
|
||
|
require.Len(t, tt.objects, 1)
|
||
|
return tt.objects[0]
|
||
|
}
|
||
|
|
||
|
func writeObject(t *testing.T, ctx context.Context, target transformer.ChunkedObjectWriter, header *objectSDK.Object, payload []byte) *transformer.AccessIdentifiers {
|
||
|
require.NoError(t, target.WriteHeader(ctx, header))
|
||
|
|
||
|
_, err := target.Write(ctx, payload)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
ids, err := target.Close(ctx)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
return ids
|
||
|
}
|
||
|
|
||
|
func verifyReconstruction(t *testing.T, original, reconstructed *objectSDK.Object) {
|
||
|
require.True(t, reconstructed.VerifyIDSignature())
|
||
|
reconstructed.ToV2().SetMarshalData(nil)
|
||
|
original.ToV2().SetMarshalData(nil)
|
||
|
|
||
|
require.Equal(t, original, reconstructed)
|
||
|
}
|
||
|
|
||
|
func deepCopy(t *testing.T, obj *objectSDK.Object) *objectSDK.Object {
|
||
|
data, err := obj.Marshal()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
res := objectSDK.New()
|
||
|
require.NoError(t, res.Unmarshal(data))
|
||
|
return res
|
||
|
}
|
||
|
|
||
|
// cloneSlice returns a shallow copy of src. The result is a freshly allocated,
// never-nil slice with len and cap equal to len(src), holding the same element
// values; pointer elements still point at the same objects.
func cloneSlice[T any](src []T) []T {
	return append(make([]T, 0, len(src)), src...)
}
|
||
|
|
||
|
// dummyEpochSource is a fixed-value network state stub: CurrentEpoch always
// reports the number the value was created from.
type dummyEpochSource uint64

// CurrentEpoch returns the stored epoch number unchanged.
func (e dummyEpochSource) CurrentEpoch() uint64 { return uint64(e) }
|
||
|
|
||
|
type testTarget struct {
|
||
|
objects []*objectSDK.Object
|
||
|
}
|
||
|
|
||
|
func (tt *testTarget) WriteObject(_ context.Context, o *objectSDK.Object) error {
|
||
|
tt.objects = append(tt.objects, o)
|
||
|
return nil // AccessIdentifiers should not be used.
|
||
|
}
|