diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index ab72791fc..20b6ca120 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -86,7 +86,7 @@ type headersOpFunc func(headerList *HeaderHashList) func NewBlockchain(s storage.Store, cfg config.ProtocolConfiguration) (*Blockchain, error) { bc := &Blockchain{ config: cfg, - dao: &dao{store: storage.NewMemCachedStore(s)}, + dao: newDao(s), headersOp: make(chan headersOpFunc), headersOpDone: make(chan struct{}), stopCh: make(chan struct{}), @@ -344,7 +344,7 @@ func (bc *Blockchain) processHeader(h *Header, batch storage.Batch, headerList * // is happening here, quite allot as you can see :). If things are wired together // and all tests are in place, we can make a more optimized and cleaner implementation. func (bc *Blockchain) storeBlock(block *Block) error { - cache := &dao{store: storage.NewMemCachedStore(bc.dao.store)} + cache := newCachedDao(bc.dao.store) if err := cache.StoreAsBlock(block, 0); err != nil { return err } @@ -505,7 +505,7 @@ func (bc *Blockchain) storeBlock(block *Block) error { v.LoadScript(t.Script) err := v.Run() if !v.HasFailed() { - _, err := systemInterop.dao.store.Persist() + _, err := systemInterop.dao.Persist() if err != nil { return errors.Wrap(err, "failed to persist invocation results") } @@ -554,7 +554,7 @@ func (bc *Blockchain) storeBlock(block *Block) error { } } } - _, err := cache.store.Persist() + _, err := cache.Persist() if err != nil { return err } @@ -567,7 +567,7 @@ func (bc *Blockchain) storeBlock(block *Block) error { } // processOutputs processes transaction outputs. -func processOutputs(tx *transaction.Transaction, dao *dao) error { +func processOutputs(tx *transaction.Transaction, dao *cachedDao) error { for index, output := range tx.Outputs { account, err := dao.GetAccountStateOrNew(output.ScriptHash) if err != nil { @@ -588,7 +588,7 @@ func processOutputs(tx *transaction.Transaction, dao *dao) error { return nil } -func processTXWithValidatorsAdd(output *transaction.Output, account *state.Account, dao *dao) error { +func processTXWithValidatorsAdd(output *transaction.Output, account *state.Account, dao *cachedDao) error { if output.AssetID.Equals(governingTokenTX().Hash()) && len(account.Votes) > 0 { for _, vote := range account.Votes { validatorState, err := dao.GetValidatorStateOrNew(vote) @@ -604,7 +604,7 @@ func processTXWithValidatorsAdd(output *transaction.Output, account *state.Accou return nil } -func processTXWithValidatorsSubtract(account *state.Account, dao *dao, toSubtract util.Fixed8) error { +func processTXWithValidatorsSubtract(account *state.Account, dao *cachedDao, toSubtract util.Fixed8) error { for _, vote := range account.Votes { validator, err := dao.GetValidatorStateOrNew(vote) if err != nil { @@ -624,7 +624,7 @@ func processTXWithValidatorsSubtract(account *state.Account, dao *dao, toSubtrac return nil } -func processValidatorStateDescriptor(descriptor *transaction.StateDescriptor, dao *dao) error { +func processValidatorStateDescriptor(descriptor *transaction.StateDescriptor, dao *cachedDao) error { publicKey := &keys.PublicKey{} err := publicKey.DecodeBytes(descriptor.Key) if err != nil { @@ -645,7 +645,7 @@ func processValidatorStateDescriptor(descriptor *transaction.StateDescriptor, da return nil } -func processAccountStateDescriptor(descriptor *transaction.StateDescriptor, dao *dao) error { +func processAccountStateDescriptor(descriptor *transaction.StateDescriptor, dao *cachedDao) error { hash, err := util.Uint160DecodeBytesBE(descriptor.Key) if err != nil { return err @@ -690,7 +690,7 @@ func (bc *Blockchain) persist() error { err error ) - persisted, err = bc.dao.store.Persist() + persisted, err = bc.dao.Persist() if err != nil { return err } @@ -1156,7 +1156,7 @@ func (bc *Blockchain) GetStandByValidators() (keys.PublicKeys, error) { // GetValidators returns validators. // Golang implementation of GetValidators method in C# (https://github.com/neo-project/neo/blob/c64748ecbac3baeb8045b16af0d518398a6ced24/neo/Persistence/Snapshot.cs#L182) func (bc *Blockchain) GetValidators(txes ...*transaction.Transaction) ([]*keys.PublicKey, error) { - cache := &dao{store: storage.NewMemCachedStore(bc.dao.store)} + cache := newCachedDao(bc.dao.store) if len(txes) > 0 { for _, tx := range txes { // iterate through outputs @@ -1249,14 +1249,10 @@ func (bc *Blockchain) GetValidators(txes ...*transaction.Transaction) ([]*keys.P for i := 0; i < uniqueSBValidators.Len() && result.Len() < count; i++ { result = append(result, uniqueSBValidators[i]) } - _, err = cache.store.Persist() - if err != nil { - return nil, err - } return result, nil } -func processStateTX(dao *dao, tx *transaction.StateTX) error { +func processStateTX(dao *cachedDao, tx *transaction.StateTX) error { for _, desc := range tx.Descriptors { switch desc.Type { case transaction.Account: @@ -1272,7 +1268,7 @@ func processStateTX(dao *dao, tx *transaction.StateTX) error { return nil } -func processEnrollmentTX(dao *dao, tx *transaction.EnrollmentTX) error { +func processEnrollmentTX(dao *cachedDao, tx *transaction.EnrollmentTX) error { validatorState, err := dao.GetValidatorStateOrNew(&tx.PublicKey) if err != nil { return err @@ -1340,8 +1336,8 @@ func (bc *Blockchain) GetScriptHashesForVerifying(t *transaction.Transaction) ([ func (bc *Blockchain) spawnVMWithInterops(interopCtx *interopContext) *vm.VM { vm := vm.New() vm.SetScriptGetter(func(hash util.Uint160) []byte { - cs := bc.GetContractState(hash) - if cs == nil { + cs, err := interopCtx.dao.GetContractState(hash) + if err != nil { return nil } return cs.Script diff --git a/pkg/core/cacheddao.go b/pkg/core/cacheddao.go new file mode 100644 index 000000000..d58acdedf --- /dev/null +++ b/pkg/core/cacheddao.go @@ -0,0 +1,82 @@ +package core + +import ( + "github.com/CityOfZion/neo-go/pkg/core/state" + "github.com/CityOfZion/neo-go/pkg/core/storage" + "github.com/CityOfZion/neo-go/pkg/util" +) + +// cachedDao is a data access object that mimics dao, but has a write cache +// for accounts and read cache for contracts. These are the most frequently used +// objects in the storeBlock(). +type cachedDao struct { + dao + accounts map[util.Uint160]*state.Account + contracts map[util.Uint160]*state.Contract +} + +// newCachedDao returns new cachedDao wrapping around given backing store. +func newCachedDao(backend storage.Store) *cachedDao { + accs := make(map[util.Uint160]*state.Account) + ctrs := make(map[util.Uint160]*state.Contract) + return &cachedDao{*newDao(backend), accs, ctrs} +} + +// GetAccountStateOrNew retrieves Account from cache or underlying Store +// or creates a new one if it doesn't exist. +func (cd *cachedDao) GetAccountStateOrNew(hash util.Uint160) (*state.Account, error) { + if cd.accounts[hash] != nil { + return cd.accounts[hash], nil + } + return cd.dao.GetAccountStateOrNew(hash) +} + +// GetAccountState retrieves Account from cache or underlying Store. +func (cd *cachedDao) GetAccountState(hash util.Uint160) (*state.Account, error) { + if cd.accounts[hash] != nil { + return cd.accounts[hash], nil + } + return cd.dao.GetAccountState(hash) +} + +// PutAccountState saves given Account in the cache. +func (cd *cachedDao) PutAccountState(as *state.Account) error { + cd.accounts[as.ScriptHash] = as + return nil +} + +// GetContractState returns contract state from cache or underlying Store. +func (cd *cachedDao) GetContractState(hash util.Uint160) (*state.Contract, error) { + if cd.contracts[hash] != nil { + return cd.contracts[hash], nil + } + cs, err := cd.dao.GetContractState(hash) + if err == nil { + cd.contracts[hash] = cs + } + return cs, err +} + +// PutContractState puts given contract state into the given store. +func (cd *cachedDao) PutContractState(cs *state.Contract) error { + cd.contracts[cs.ScriptHash()] = cs + return cd.dao.PutContractState(cs) +} + +// DeleteContractState deletes given contract state in cache and backing Store. +func (cd *cachedDao) DeleteContractState(hash util.Uint160) error { + cd.contracts[hash] = nil + return cd.dao.DeleteContractState(hash) +} + +// Persist flushes all the changes made into the (supposedly) persistent +// underlying store. +func (cd *cachedDao) Persist() (int, error) { + for sc := range cd.accounts { + err := cd.dao.PutAccountState(cd.accounts[sc]) + if err != nil { + return 0, err + } + } + return cd.dao.Persist() +} diff --git a/pkg/core/cacheddao_test.go b/pkg/core/cacheddao_test.go new file mode 100644 index 000000000..9020fc73e --- /dev/null +++ b/pkg/core/cacheddao_test.go @@ -0,0 +1,83 @@ +package core + +import ( + "testing" + + "github.com/CityOfZion/neo-go/pkg/core/state" + "github.com/CityOfZion/neo-go/pkg/core/storage" + "github.com/CityOfZion/neo-go/pkg/crypto/hash" + "github.com/CityOfZion/neo-go/pkg/internal/random" + "github.com/CityOfZion/neo-go/pkg/smartcontract" + "github.com/stretchr/testify/require" +) + +func TestCachedDaoAccounts(t *testing.T) { + store := storage.NewMemoryStore() + // Persistent DAO to check for backing storage. + pdao := newDao(store) + // Cached DAO. + cdao := newCachedDao(store) + + hash := random.Uint160() + _, err := cdao.GetAccountState(hash) + require.NotNil(t, err) + + acc, err := cdao.GetAccountStateOrNew(hash) + require.Nil(t, err) + _, err = pdao.GetAccountState(hash) + require.NotNil(t, err) + + acc.Version = 42 + require.NoError(t, cdao.PutAccountState(acc)) + _, err = pdao.GetAccountState(hash) + require.NotNil(t, err) + + acc2, err := cdao.GetAccountState(hash) + require.Nil(t, err) + require.Equal(t, acc, acc2) + + acc2, err = cdao.GetAccountStateOrNew(hash) + require.Nil(t, err) + require.Equal(t, acc, acc2) + + _, err = cdao.Persist() + require.Nil(t, err) + + acct, err := pdao.GetAccountState(hash) + require.Nil(t, err) + require.Equal(t, acc, acct) +} + +func TestCachedDaoContracts(t *testing.T) { + store := storage.NewMemoryStore() + dao := newCachedDao(store) + + script := []byte{0xde, 0xad, 0xbe, 0xef} + sh := hash.Hash160(script) + _, err := dao.GetContractState(sh) + require.NotNil(t, err) + + cs := &state.Contract{} + cs.Name = "test" + cs.Script = script + cs.ParamList = []smartcontract.ParamType{1, 2} + + require.NoError(t, dao.PutContractState(cs)) + cs2, err := dao.GetContractState(sh) + require.Nil(t, err) + require.Equal(t, cs, cs2) + + _, err = dao.Persist() + require.Nil(t, err) + dao2 := newCachedDao(store) + cs2, err = dao2.GetContractState(sh) + require.Nil(t, err) + require.Equal(t, cs, cs2) + + require.NoError(t, dao.DeleteContractState(sh)) + cs2, err = dao2.GetContractState(sh) + require.Nil(t, err) + require.Equal(t, cs, cs2) + _, err = dao.GetContractState(sh) + require.NotNil(t, err) +} diff --git a/pkg/core/dao.go b/pkg/core/dao.go index 1a0c8cea3..d570d5e3e 100644 --- a/pkg/core/dao.go +++ b/pkg/core/dao.go @@ -19,6 +19,10 @@ type dao struct { store *storage.MemCachedStore } +func newDao(backend storage.Store) *dao { + return &dao{store: storage.NewMemCachedStore(backend)} +} + // GetAndDecode performs get operation and decoding with serializable structures. func (dao *dao) GetAndDecode(entity io.Serializable, key []byte) error { entityBytes, err := dao.store.Get(key) @@ -51,9 +55,6 @@ func (dao *dao) GetAccountStateOrNew(hash util.Uint160) (*state.Account, error) return nil, err } account = state.NewAccount(hash) - if err = dao.PutAccountState(account); err != nil { - return nil, err - } } return account, nil } @@ -147,9 +148,6 @@ func (dao *dao) GetUnspentCoinStateOrNew(hash util.Uint256) (*UnspentCoinState, unspent = &UnspentCoinState{ states: []state.Coin{}, } - if err = dao.PutUnspentCoinState(hash, unspent); err != nil { - return nil, err - } } return unspent, nil } @@ -185,9 +183,6 @@ func (dao *dao) GetSpentCoinsOrNew(hash util.Uint256) (*SpentCoinState, error) { spent = &SpentCoinState{ items: make(map[uint16]uint32), } - if err = dao.PutSpentCoinState(hash, spent); err != nil { - return nil, err - } } return spent, nil } @@ -227,9 +222,6 @@ func (dao *dao) GetValidatorStateOrNew(publicKey *keys.PublicKey) (*state.Valida return nil, err } validatorState = &state.Validator{PublicKey: publicKey} - if err = dao.PutValidatorState(validatorState); err != nil { - return nil, err - } } return validatorState, nil @@ -551,3 +543,9 @@ func (dao *dao) IsDoubleSpend(tx *transaction.Transaction) bool { } return false } + +// Persist flushes all the changes made into the (supposedly) persistent +// underlying store. +func (dao *dao) Persist() (int, error) { + return dao.store.Persist() +} diff --git a/pkg/core/dao_test.go b/pkg/core/dao_test.go index 5aa44fe9f..d747ddf19 100644 --- a/pkg/core/dao_test.go +++ b/pkg/core/dao_test.go @@ -15,7 +15,7 @@ import ( ) func TestPutGetAndDecode(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) serializable := &TestSerializable{field: random.String(4)} hash := []byte{1} err := dao.Put(serializable, hash) @@ -40,18 +40,15 @@ func (t *TestSerializable) DecodeBinary(reader *io.BinReader) { } func TestGetAccountStateOrNew_New(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint160() createdAccount, err := dao.GetAccountStateOrNew(hash) require.NoError(t, err) require.NotNil(t, createdAccount) - gotAccount, err := dao.GetAccountState(hash) - require.NoError(t, err) - require.Equal(t, createdAccount, gotAccount) } func TestPutAndGetAccountStateOrNew(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint160() accountState := &state.Account{ScriptHash: hash} err := dao.PutAccountState(accountState) @@ -62,7 +59,7 @@ func TestPutAndGetAccountStateOrNew(t *testing.T) { } func TestPutAndGetAssetState(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) id := random.Uint256() assetState := &state.Asset{ID: id, Owner: keys.PublicKey{}} err := dao.PutAssetState(assetState) @@ -73,8 +70,8 @@ func TestPutAndGetAssetState(t *testing.T) { } func TestPutAndGetContractState(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} - contractState := &state.Contract{Script: []byte{}, ParamList:[]smartcontract.ParamType{}} + dao := newDao(storage.NewMemoryStore()) + contractState := &state.Contract{Script: []byte{}, ParamList: []smartcontract.ParamType{}} hash := contractState.ScriptHash() err := dao.PutContractState(contractState) require.NoError(t, err) @@ -84,8 +81,8 @@ func TestPutAndGetContractState(t *testing.T) { } func TestDeleteContractState(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} - contractState := &state.Contract{Script: []byte{}, ParamList:[]smartcontract.ParamType{}} + dao := newDao(storage.NewMemoryStore()) + contractState := &state.Contract{Script: []byte{}, ParamList: []smartcontract.ParamType{}} hash := contractState.ScriptHash() err := dao.PutContractState(contractState) require.NoError(t, err) @@ -97,18 +94,15 @@ func TestDeleteContractState(t *testing.T) { } func TestGetUnspentCoinStateOrNew_New(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() unspentCoinState, err := dao.GetUnspentCoinStateOrNew(hash) require.NoError(t, err) require.NotNil(t, unspentCoinState) - gotUnspentCoinState, err := dao.GetUnspentCoinState(hash) - require.NoError(t, err) - require.Equal(t, unspentCoinState, gotUnspentCoinState) } func TestGetUnspentCoinState_Err(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() gotUnspentCoinState, err := dao.GetUnspentCoinState(hash) require.Error(t, err) @@ -116,9 +110,9 @@ func TestGetUnspentCoinState_Err(t *testing.T) { } func TestPutGetUnspentCoinState(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() - unspentCoinState := &UnspentCoinState{states:[]state.Coin{}} + unspentCoinState := &UnspentCoinState{states: []state.Coin{}} err := dao.PutUnspentCoinState(hash, unspentCoinState) require.NoError(t, err) gotUnspentCoinState, err := dao.GetUnspentCoinState(hash) @@ -127,20 +121,17 @@ func TestPutGetUnspentCoinState(t *testing.T) { } func TestGetSpentCoinStateOrNew_New(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() spentCoinState, err := dao.GetSpentCoinsOrNew(hash) require.NoError(t, err) require.NotNil(t, spentCoinState) - gotSpentCoinState, err := dao.GetSpentCoinState(hash) - require.NoError(t, err) - require.Equal(t, spentCoinState, gotSpentCoinState) } func TestPutAndGetSpentCoinState(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() - spentCoinState := &SpentCoinState{items:make(map[uint16]uint32)} + spentCoinState := &SpentCoinState{items: make(map[uint16]uint32)} err := dao.PutSpentCoinState(hash, spentCoinState) require.NoError(t, err) gotSpentCoinState, err := dao.GetSpentCoinState(hash) @@ -149,7 +140,7 @@ func TestPutAndGetSpentCoinState(t *testing.T) { } func TestGetSpentCoinState_Err(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() spentCoinState, err := dao.GetSpentCoinState(hash) require.Error(t, err) @@ -157,9 +148,9 @@ func TestGetSpentCoinState_Err(t *testing.T) { } func TestDeleteSpentCoinState(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() - spentCoinState := &SpentCoinState{items:make(map[uint16]uint32)} + spentCoinState := &SpentCoinState{items: make(map[uint16]uint32)} err := dao.PutSpentCoinState(hash, spentCoinState) require.NoError(t, err) err = dao.DeleteSpentCoinState(hash) @@ -170,18 +161,15 @@ func TestDeleteSpentCoinState(t *testing.T) { } func TestGetValidatorStateOrNew_New(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) publicKey := &keys.PublicKey{} validatorState, err := dao.GetValidatorStateOrNew(publicKey) require.NoError(t, err) require.NotNil(t, validatorState) - gotValidatorState, err := dao.GetValidatorState(publicKey) - require.NoError(t, err) - require.Equal(t, validatorState, gotValidatorState) } func TestPutGetValidatorState(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) publicKey := &keys.PublicKey{} validatorState := &state.Validator{ PublicKey: publicKey, @@ -196,7 +184,7 @@ func TestPutGetValidatorState(t *testing.T) { } func TestDeleteValidatorState(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) publicKey := &keys.PublicKey{} validatorState := &state.Validator{ PublicKey: publicKey, @@ -213,7 +201,7 @@ func TestDeleteValidatorState(t *testing.T) { } func TestGetValidators(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) publicKey := &keys.PublicKey{} validatorState := &state.Validator{ PublicKey: publicKey, @@ -228,9 +216,9 @@ func TestGetValidators(t *testing.T) { } func TestPutGetAppExecResult(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() - appExecResult := &state.AppExecResult{TxHash: hash, Events:[]state.NotificationEvent{}} + appExecResult := &state.AppExecResult{TxHash: hash, Events: []state.NotificationEvent{}} err := dao.PutAppExecResult(appExecResult) require.NoError(t, err) gotAppExecResult, err := dao.GetAppExecResult(hash) @@ -239,7 +227,7 @@ func TestPutGetAppExecResult(t *testing.T) { } func TestPutGetStorageItem(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint160() key := []byte{0} storageItem := &state.StorageItem{Value: []uint8{}} @@ -250,7 +238,7 @@ func TestPutGetStorageItem(t *testing.T) { } func TestDeleteStorageItem(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint160() key := []byte{0} storageItem := &state.StorageItem{Value: []uint8{}} @@ -263,7 +251,7 @@ func TestDeleteStorageItem(t *testing.T) { } func TestGetBlock_NotExists(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) hash := random.Uint256() block, err := dao.GetBlock(hash) require.Error(t, err) @@ -271,7 +259,7 @@ func TestGetBlock_NotExists(t *testing.T) { } func TestPutGetBlock(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) block := &Block{ BlockBase: BlockBase{ Script: transaction.Witness{ @@ -289,14 +277,14 @@ func TestPutGetBlock(t *testing.T) { } func TestGetVersion_NoVersion(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) version, err := dao.GetVersion() require.Error(t, err) require.Equal(t, "", version) } func TestGetVersion(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) err := dao.PutVersion("testVersion") require.NoError(t, err) version, err := dao.GetVersion() @@ -305,14 +293,14 @@ func TestGetVersion(t *testing.T) { } func TestGetCurrentHeaderHeight_NoHeader(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) height, err := dao.GetCurrentBlockHeight() require.Error(t, err) require.Equal(t, uint32(0), height) } func TestGetCurrentHeaderHeight_Store(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) block := &Block{ BlockBase: BlockBase{ Script: transaction.Witness{ @@ -329,7 +317,7 @@ func TestGetCurrentHeaderHeight_Store(t *testing.T) { } func TestStoreAsTransaction(t *testing.T) { - dao := &dao{store: storage.NewMemCachedStore(storage.NewMemoryStore())} + dao := newDao(storage.NewMemoryStore()) tx := &transaction.Transaction{} hash := tx.Hash() err := dao.StoreAsTransaction(tx, 0) diff --git a/pkg/core/interop_neo.go b/pkg/core/interop_neo.go index 31576d95a..44c55b225 100644 --- a/pkg/core/interop_neo.go +++ b/pkg/core/interop_neo.go @@ -172,8 +172,8 @@ func (ic *interopContext) txGetUnspentCoins(v *vm.VM) error { if !ok { return errors.New("value is not a transaction") } - ucs := ic.bc.GetUnspentCoinState(tx.Hash()) - if ucs == nil { + ucs, err := ic.dao.GetUnspentCoinState(tx.Hash()) + if err != nil { return errors.New("no unspent coin state found") } v.Estack().PushVal(vm.NewInteropItem(ucs)) @@ -200,10 +200,7 @@ func (ic *interopContext) txGetWitnesses(v *vm.VM) error { // bcGetValidators returns validators. func (ic *interopContext) bcGetValidators(v *vm.VM) error { - validators, err := ic.bc.GetValidators() - if err != nil { - return err - } + validators := ic.dao.GetValidators() v.Estack().PushVal(validators) return nil } @@ -315,9 +312,9 @@ func (ic *interopContext) bcGetAccount(v *vm.VM) error { if err != nil { return err } - acc := ic.bc.GetAccountState(acchash) - if acc == nil { - acc = state.NewAccount(acchash) + acc, err := ic.dao.GetAccountStateOrNew(acchash) + if err != nil { + return err } v.Estack().PushVal(vm.NewInteropItem(acc)) return nil @@ -330,8 +327,8 @@ func (ic *interopContext) bcGetAsset(v *vm.VM) error { if err != nil { return err } - as := ic.bc.GetAssetState(ashash) - if as == nil { + as, err := ic.dao.GetAssetState(ashash) + if err != nil { return errors.New("asset not found") } v.Estack().PushVal(vm.NewInteropItem(as)) @@ -394,8 +391,8 @@ func (ic *interopContext) accountIsStandard(v *vm.VM) error { if err != nil { return err } - contract := ic.bc.GetContractState(acchash) - res := contract == nil || vm.IsStandardContract(contract.Script) + contract, err := ic.dao.GetContractState(acchash) + res := err != nil || vm.IsStandardContract(contract.Script) v.Estack().PushVal(res) return nil } @@ -413,7 +410,7 @@ func (ic *interopContext) storageFind(v *vm.VM) error { return err } prefix := string(v.Estack().Pop().Bytes()) - siMap, err := ic.bc.GetStorageItems(stc.ScriptHash) + siMap, err := ic.dao.GetStorageItems(stc.ScriptHash) if err != nil { return err } @@ -488,8 +485,8 @@ func (ic *interopContext) contractCreate(v *vm.VM) error { if err != nil { return nil } - contract := ic.bc.GetContractState(newcontract.ScriptHash()) - if contract == nil { + contract, err := ic.dao.GetContractState(newcontract.ScriptHash()) + if err != nil { contract = newcontract err := ic.dao.PutContractState(contract) if err != nil { @@ -528,8 +525,8 @@ func (ic *interopContext) contractMigrate(v *vm.VM) error { if err != nil { return nil } - contract := ic.bc.GetContractState(newcontract.ScriptHash()) - if contract == nil { + contract, err := ic.dao.GetContractState(newcontract.ScriptHash()) + if err != nil { contract = newcontract err := ic.dao.PutContractState(contract) if err != nil { @@ -537,7 +534,7 @@ func (ic *interopContext) contractMigrate(v *vm.VM) error { } if contract.HasStorage() { hash := getContextScriptHash(v, 0) - siMap, err := ic.bc.GetStorageItems(hash) + siMap, err := ic.dao.GetStorageItems(hash) if err != nil { return err } @@ -729,8 +726,8 @@ func (ic *interopContext) assetRenew(v *vm.VM) error { } years := byte(v.Estack().Pop().BigInt().Int64()) // Not sure why C# code regets an asset from the Store, but we also do it. - asset := ic.bc.GetAssetState(as.ID) - if asset == nil { + asset, err := ic.dao.GetAssetState(as.ID) + if err != nil { return errors.New("can't renew non-existent asset") } if asset.Expiration < ic.bc.BlockHeight()+1 { @@ -741,7 +738,7 @@ func (ic *interopContext) assetRenew(v *vm.VM) error { expiration = math.MaxUint32 } asset.Expiration = uint32(expiration) - err := ic.dao.PutAssetState(asset) + err = ic.dao.PutAssetState(asset) if err != nil { return gherr.Wrap(err, "failed to store asset") } diff --git a/pkg/core/interop_system.go b/pkg/core/interop_system.go index a399e67e7..91f919d32 100644 --- a/pkg/core/interop_system.go +++ b/pkg/core/interop_system.go @@ -68,8 +68,8 @@ func (ic *interopContext) bcGetContract(v *vm.VM) error { if err != nil { return err } - cs := ic.bc.GetContractState(hash) - if cs == nil { + cs, err := ic.dao.GetContractState(hash) + if err != nil { v.Estack().PushVal([]byte{}) } else { v.Estack().PushVal(vm.NewInteropItem(cs)) @@ -100,18 +100,18 @@ func (ic *interopContext) bcGetHeight(v *vm.VM) error { // getTransactionAndHeight gets parameter from the vm evaluation stack and // returns transaction and its height if it's present in the blockchain. -func getTransactionAndHeight(bc Blockchainer, v *vm.VM) (*transaction.Transaction, uint32, error) { +func getTransactionAndHeight(cd *cachedDao, v *vm.VM) (*transaction.Transaction, uint32, error) { hashbytes := v.Estack().Pop().Bytes() hash, err := util.Uint256DecodeBytesLE(hashbytes) if err != nil { return nil, 0, err } - return bc.GetTransaction(hash) + return cd.GetTransaction(hash) } // bcGetTransaction returns transaction. func (ic *interopContext) bcGetTransaction(v *vm.VM) error { - tx, _, err := getTransactionAndHeight(ic.bc, v) + tx, _, err := getTransactionAndHeight(ic.dao, v) if err != nil { return err } @@ -121,7 +121,7 @@ func (ic *interopContext) bcGetTransaction(v *vm.VM) error { // bcGetTransactionHeight returns transaction height. func (ic *interopContext) bcGetTransactionHeight(v *vm.VM) error { - _, h, err := getTransactionAndHeight(ic.bc, v) + _, h, err := getTransactionAndHeight(ic.dao, v) if err != nil { return err } @@ -254,7 +254,7 @@ func (ic *interopContext) engineGetScriptContainer(v *vm.VM) error { func getContextScriptHash(v *vm.VM, n int) util.Uint160 { ctxIface := v.Istack().Peek(n).Value() ctx := ctxIface.(*vm.Context) - return hash.Hash160(ctx.Program()) + return ctx.ScriptHash() } // pushContextScriptHash pushes to evaluation stack the script hash of the @@ -383,8 +383,8 @@ func (ic *interopContext) runtimeDeserialize(v *vm.VM) error { } */ func (ic *interopContext) checkStorageContext(stc *StorageContext) error { - contract := ic.bc.GetContractState(stc.ScriptHash) - if contract == nil { + contract, err := ic.dao.GetContractState(stc.ScriptHash) + if err != nil { return errors.New("no contract found") } if !contract.HasStorage() { @@ -535,16 +535,16 @@ func (ic *interopContext) contractDestroy(v *vm.VM) error { return errors.New("can't destroy contract when not triggered by application") } hash := getContextScriptHash(v, 0) - cs := ic.bc.GetContractState(hash) - if cs == nil { + cs, err := ic.dao.GetContractState(hash) + if err != nil { return nil } - err := ic.dao.DeleteContractState(hash) + err = ic.dao.DeleteContractState(hash) if err != nil { return err } if cs.HasStorage() { - siMap, err := ic.bc.GetStorageItems(hash) + siMap, err := ic.dao.GetStorageItems(hash) if err != nil { return err } diff --git a/pkg/core/interops.go b/pkg/core/interops.go index 33fac36c6..f249e0299 100644 --- a/pkg/core/interops.go +++ b/pkg/core/interops.go @@ -19,12 +19,12 @@ type interopContext struct { trigger byte block *Block tx *transaction.Transaction - dao *dao + dao *cachedDao notifications []state.NotificationEvent } func newInteropContext(trigger byte, bc Blockchainer, s storage.Store, block *Block, tx *transaction.Transaction) *interopContext { - dao := &dao{store: storage.NewMemCachedStore(s)} + dao := newCachedDao(s) nes := make([]state.NotificationEvent, 0) return &interopContext{bc, trigger, block, tx, dao, nes} } diff --git a/pkg/core/mem_pool.go b/pkg/core/mem_pool.go index d9e489c0f..a01793a5d 100644 --- a/pkg/core/mem_pool.go +++ b/pkg/core/mem_pool.go @@ -269,12 +269,15 @@ func min(sortedPool PoolItems) *PoolItem { // GetVerifiedTransactions returns a slice of Input from all the transactions in the memory pool // whose hash is not included in excludedHashes. func (mp *MemPool) GetVerifiedTransactions() []*transaction.Transaction { - var t []*transaction.Transaction - mp.lock.RLock() defer mp.lock.RUnlock() + + var t = make([]*transaction.Transaction, len(mp.unsortedTxn)) + var i int + for _, p := range mp.unsortedTxn { - t = append(t, p.txn) + t[i] = p.txn + i++ } return t diff --git a/pkg/core/state/account.go b/pkg/core/state/account.go index d1a016d6c..12522b0d4 100644 --- a/pkg/core/state/account.go +++ b/pkg/core/state/account.go @@ -49,8 +49,11 @@ func (s *Account) DecodeBinary(br *io.BinReader) { for i := 0; i < int(lenBalances); i++ { key := util.Uint256{} br.ReadBytes(key[:]) - ubs := make([]UnspentBalance, 0) - br.ReadArray(&ubs) + len := int(br.ReadVarUint()) + ubs := make([]UnspentBalance, len) + for j := 0; j < len; j++ { + ubs[j].DecodeBinary(br) + } s.Balances[key] = ubs } } @@ -65,7 +68,10 @@ func (s *Account) EncodeBinary(bw *io.BinWriter) { bw.WriteVarUint(uint64(len(s.Balances))) for k, v := range s.Balances { bw.WriteBytes(k[:]) - bw.WriteArray(v) + bw.WriteVarUint(uint64(len(v))) + for i := range v { + v[i].EncodeBinary(bw) + } } } diff --git a/pkg/vm/context.go b/pkg/vm/context.go index bf42b4440..f1d5899df 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -1,9 +1,11 @@ package vm import ( + "encoding/binary" "errors" - "github.com/CityOfZion/neo-go/pkg/io" + "github.com/CityOfZion/neo-go/pkg/crypto/hash" + "github.com/CityOfZion/neo-go/pkg/util" "github.com/CityOfZion/neo-go/pkg/vm/opcode" ) @@ -29,8 +31,13 @@ type Context struct { // Alt stack pointer. astack *Stack + + // Script hash of the prog. + scriptHash util.Uint160 } +var errNoInstParam = errors.New("failed to read instruction parameter") + // NewContext returns a new Context object. func NewContext(b []byte) *Context { return &Context{ @@ -44,33 +51,44 @@ func NewContext(b []byte) *Context { // its invocation the instruction pointer points to the instruction being // returned. func (c *Context) Next() (opcode.Opcode, []byte, error) { + var err error + c.ip = c.nextip if c.ip >= len(c.prog) { return opcode.RET, nil, nil } - r := io.NewBinReaderFromBuf(c.prog[c.ip:]) - var instrbyte = r.ReadB() + var instrbyte = c.prog[c.ip] instr := opcode.Opcode(instrbyte) c.nextip++ var numtoread int switch instr { case opcode.PUSHDATA1, opcode.SYSCALL: - var n = r.ReadB() - numtoread = int(n) - c.nextip++ - case opcode.PUSHDATA2: - var n = r.ReadU16LE() - numtoread = int(n) - c.nextip += 2 - case opcode.PUSHDATA4: - var n = r.ReadU32LE() - if n > MaxItemSize { - return instr, nil, errors.New("parameter is too big") + if c.nextip >= len(c.prog) { + err = errNoInstParam + } else { + numtoread = int(c.prog[c.nextip]) + c.nextip++ + } + case opcode.PUSHDATA2: + if c.nextip+1 >= len(c.prog) { + err = errNoInstParam + } else { + numtoread = int(binary.LittleEndian.Uint16(c.prog[c.nextip : c.nextip+2])) + c.nextip += 2 + } + case opcode.PUSHDATA4: + if c.nextip+3 >= len(c.prog) { + err = errNoInstParam + } else { + var n = binary.LittleEndian.Uint32(c.prog[c.nextip : c.nextip+4]) + if n > MaxItemSize { + return instr, nil, errors.New("parameter is too big") + } + numtoread = int(n) + c.nextip += 4 } - numtoread = int(n) - c.nextip += 4 case opcode.JMP, opcode.JMPIF, opcode.JMPIFNOT, opcode.CALL, opcode.CALLED, opcode.CALLEDT: numtoread = 2 case opcode.CALLI: @@ -87,11 +105,14 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { return instr, nil, nil } } - parameter := make([]byte, numtoread) - r.ReadBytes(parameter) - if r.Err != nil { - return instr, nil, errors.New("failed to read instruction parameter") + if c.nextip+numtoread-1 >= len(c.prog) { + err = errNoInstParam } + if err != nil { + return instr, nil, err + } + parameter := make([]byte, numtoread) + copy(parameter, c.prog[c.nextip:c.nextip+numtoread]) c.nextip += numtoread return instr, parameter, nil } @@ -125,6 +146,14 @@ func (c *Context) Program() []byte { return c.prog } +// ScriptHash returns a hash of the script in the current context. +func (c *Context) ScriptHash() util.Uint160 { + if c.scriptHash.Equals(util.Uint160{}) { + c.scriptHash = hash.Hash160(c.prog) + } + return c.scriptHash +} + // Value implements StackItem interface. func (c *Context) Value() interface{} { return c diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 0b488f83f..5d784e83e 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -129,8 +129,8 @@ func (v *VM) RegisterInteropFunc(name string, f InteropFunc, price int) { // the VM. Effectively it's a batched version of RegisterInteropFunc. func (v *VM) RegisterInteropFuncs(interops map[string]InteropFuncPrice) { // We allow reregistration here. - for name, funPrice := range interops { - v.interop[name] = funPrice + for name := range interops { + v.interop[name] = interops[name] } }