diff --git a/pkg/wallet/wallet.go b/pkg/wallet/wallet.go index 413a318fb..c54aa20cc 100644 --- a/pkg/wallet/wallet.go +++ b/pkg/wallet/wallet.go @@ -6,6 +6,7 @@ import ( "os" "github.com/CityOfZion/neo-go/pkg/crypto/keys" + "github.com/CityOfZion/neo-go/pkg/util" ) const ( @@ -117,3 +118,14 @@ func (w *Wallet) Close() { rc.Close() } } + +// GetAccount returns account corresponding to the provided scripthash. +func (w *Wallet) GetAccount(h util.Uint160) *Account { + for _, acc := range w.Accounts { + if c := acc.Contract; c != nil && h.Equals(c.ScriptHash()) { + return acc + } + } + + return nil +} diff --git a/pkg/wallet/wallet_test.go b/pkg/wallet/wallet_test.go index 2a80ca765..dca5590da 100644 --- a/pkg/wallet/wallet_test.go +++ b/pkg/wallet/wallet_test.go @@ -6,6 +6,7 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -119,3 +120,28 @@ func removeWallet(t *testing.T, walletPath string) { err := os.RemoveAll(walletPath) require.NoError(t, err) } + +func TestWallet_GetAccount(t *testing.T) { + wallet := checkWalletConstructor(t) + accounts := []*Account{ + { + Contract: &Contract{ + Script: []byte{0, 1, 2, 3}, + }, + }, + { + Contract: &Contract{ + Script: []byte{3, 2, 1, 0}, + }, + }, + } + + for _, acc := range accounts { + wallet.AddAccount(acc) + } + + for i, acc := range accounts { + h := acc.Contract.ScriptHash() + assert.Equal(t, acc, wallet.GetAccount(h), "can't get %d account", i) + } +}