storage: use strings as keys for memory batch

Using pointers is just plain wrong here, because the batch can be updated with
newer values for the same keys.

Fixes Seek() to use HasPrefix also because this is the intended behavior.
This commit is contained in:
Roman Khimov 2019-10-07 17:05:53 +03:00
parent aab2f9a837
commit add9368e9d
5 changed files with 13 additions and 44 deletions

View file

@ -76,7 +76,7 @@ func (s *BoltDBStore) PutBatch(batch Batch) error {
return s.db.Batch(func(tx *bbolt.Tx) error { return s.db.Batch(func(tx *bbolt.Tx) error {
b := tx.Bucket(Bucket) b := tx.Bucket(Bucket)
for k, v := range batch.(*MemoryBatch).m { for k, v := range batch.(*MemoryBatch).m {
err := b.Put(*k, v) err := b.Put([]byte(k), v)
if err != nil { if err != nil {
return err return err
} }

View file

@ -3,27 +3,12 @@ package storage
import ( import (
"io/ioutil" "io/ioutil"
"os" "os"
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestBoltDBBatch(t *testing.T) {
boltDB := BoltDBStore{}
want := &MemoryBatch{m: map[*[]byte][]byte{}}
if got := boltDB.Batch(); !reflect.DeepEqual(got, want) {
t.Errorf("BoltDB Batch() = %v, want %v", got, want)
}
}
func TestBoltDBBatch_Len(t *testing.T) {
batch := &MemoryBatch{m: map[*[]byte][]byte{}}
want := len(map[*[]byte][]byte{})
assert.Equal(t, want, batch.Len())
}
func TestBoltDBBatch_PutBatchAndGet(t *testing.T) { func TestBoltDBBatch_PutBatchAndGet(t *testing.T) {
key := []byte("foo") key := []byte("foo")
keycopy := make([]byte, len(key)) keycopy := make([]byte, len(key))

View file

@ -1,7 +1,6 @@
package storage package storage
import ( import (
"encoding/hex"
"strings" "strings"
"sync" "sync"
) )
@ -15,16 +14,15 @@ type MemoryStore struct {
// MemoryBatch a in-memory batch compatible with MemoryStore. // MemoryBatch a in-memory batch compatible with MemoryStore.
type MemoryBatch struct { type MemoryBatch struct {
m map[*[]byte][]byte m map[string][]byte
} }
// Put implements the Batch interface. // Put implements the Batch interface.
func (b *MemoryBatch) Put(k, v []byte) { func (b *MemoryBatch) Put(k, v []byte) {
vcopy := make([]byte, len(v)) vcopy := make([]byte, len(v))
copy(vcopy, v) copy(vcopy, v)
kcopy := make([]byte, len(k)) kcopy := string(k)
copy(kcopy, k) b.m[kcopy] = vcopy
b.m[&kcopy] = vcopy
} }
// Len implements the Batch interface. // Len implements the Batch interface.
@ -43,7 +41,7 @@ func NewMemoryStore() *MemoryStore {
func (s *MemoryStore) Get(key []byte) ([]byte, error) { func (s *MemoryStore) Get(key []byte) ([]byte, error) {
s.mut.RLock() s.mut.RLock()
defer s.mut.RUnlock() defer s.mut.RUnlock()
if val, ok := s.mem[makeKey(key)]; ok { if val, ok := s.mem[string(key)]; ok {
return val, nil return val, nil
} }
return nil, ErrKeyNotFound return nil, ErrKeyNotFound
@ -52,7 +50,7 @@ func (s *MemoryStore) Get(key []byte) ([]byte, error) {
// Put implements the Store interface. Never returns an error. // Put implements the Store interface. Never returns an error.
func (s *MemoryStore) Put(key, value []byte) error { func (s *MemoryStore) Put(key, value []byte) error {
s.mut.Lock() s.mut.Lock()
s.mem[makeKey(key)] = value s.mem[string(key)] = value
s.mut.Unlock() s.mut.Unlock()
return nil return nil
} }
@ -61,7 +59,7 @@ func (s *MemoryStore) Put(key, value []byte) error {
func (s *MemoryStore) PutBatch(batch Batch) error { func (s *MemoryStore) PutBatch(batch Batch) error {
b := batch.(*MemoryBatch) b := batch.(*MemoryBatch)
for k, v := range b.m { for k, v := range b.m {
_ = s.Put(*k, v) _ = s.Put([]byte(k), v)
} }
return nil return nil
} }
@ -69,9 +67,8 @@ func (s *MemoryStore) PutBatch(batch Batch) error {
// Seek implements the Store interface. // Seek implements the Store interface.
func (s *MemoryStore) Seek(key []byte, f func(k, v []byte)) { func (s *MemoryStore) Seek(key []byte, f func(k, v []byte)) {
for k, v := range s.mem { for k, v := range s.mem {
if strings.Contains(k, hex.EncodeToString(key)) { if strings.HasPrefix(k, string(key)) {
decodeString, _ := hex.DecodeString(k) f([]byte(k), v)
f(decodeString, v)
} }
} }
} }
@ -84,7 +81,7 @@ func (s *MemoryStore) Batch() Batch {
// newMemoryBatch returns new memory batch. // newMemoryBatch returns new memory batch.
func newMemoryBatch() *MemoryBatch { func newMemoryBatch() *MemoryBatch {
return &MemoryBatch{ return &MemoryBatch{
m: make(map[*[]byte][]byte), m: make(map[string][]byte),
} }
} }
@ -96,8 +93,7 @@ func (s *MemoryStore) Persist(ps Store) (int, error) {
batch := ps.Batch() batch := ps.Batch()
keys := 0 keys := 0
for k, v := range s.mem { for k, v := range s.mem {
kb, _ := hex.DecodeString(k) batch.Put([]byte(k), v)
batch.Put(kb, v)
keys++ keys++
} }
var err error var err error
@ -118,7 +114,3 @@ func (s *MemoryStore) Close() error {
s.mut.Unlock() s.mut.Unlock()
return nil return nil
} }
func makeKey(k []byte) string {
return hex.EncodeToString(k)
}

View file

@ -58,7 +58,7 @@ func (s *RedisStore) Put(k, v []byte) error {
func (s *RedisStore) PutBatch(b Batch) error { func (s *RedisStore) PutBatch(b Batch) error {
pipe := s.client.Pipeline() pipe := s.client.Pipeline()
for k, v := range b.(*MemoryBatch).m { for k, v := range b.(*MemoryBatch).m {
pipe.Set(string(*k), v, 0) pipe.Set(k, v, 0)
} }
_, err := pipe.Exec() _, err := pipe.Exec()
return err return err

View file

@ -23,14 +23,6 @@ func TestNewRedisStore(t *testing.T) {
redisMock.Close() redisMock.Close()
} }
func TestRedisBatch_Len(t *testing.T) {
want := len(map[string]string{})
b := &MemoryBatch{
m: map[*[]byte][]byte{},
}
assert.Equal(t, len(b.m), want)
}
func TestRedisStore_GetAndPut(t *testing.T) { func TestRedisStore_GetAndPut(t *testing.T) {
prepareRedisMock(t) prepareRedisMock(t)
type args struct { type args struct {
@ -82,7 +74,7 @@ func TestRedisStore_GetAndPut(t *testing.T) {
} }
func TestRedisStore_PutBatch(t *testing.T) { func TestRedisStore_PutBatch(t *testing.T) {
batch := &MemoryBatch{m: map[*[]byte][]byte{&[]byte{'f', 'o', 'o', '1'}: []byte("bar1")}} batch := &MemoryBatch{m: map[string][]byte{"foo1": []byte("bar1")}}
mock, redisStore := prepareRedisMock(t) mock, redisStore := prepareRedisMock(t)
err := redisStore.PutBatch(batch) err := redisStore.PutBatch(batch)
assert.Nil(t, err, "Error while PutBatch") assert.Nil(t, err, "Error while PutBatch")