From 8f196c8222f356fbd8721108418d48a28d96cc21 Mon Sep 17 00:00:00 2001 From: Evgeniy Stratonikov Date: Thu, 29 Jul 2021 16:00:20 +0300 Subject: [PATCH 1/3] wallet: marshal before writing to file Signed-off-by: Evgeniy Stratonikov --- pkg/wallet/wallet.go | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/pkg/wallet/wallet.go b/pkg/wallet/wallet.go index c42e65ef1..a4694bf4a 100644 --- a/pkg/wallet/wallet.go +++ b/pkg/wallet/wallet.go @@ -143,20 +143,31 @@ func (w *Wallet) Path() string { // that is responsible for saving the data. This can // be a buffer, file, etc.. func (w *Wallet) Save() error { - if err := w.rewind(); err != nil { + data, err := json.Marshal(w) + if err != nil { return err } - return json.NewEncoder(w.rw).Encode(w) + + return w.writeRaw(data) } // savePretty saves wallet in a beautiful JSON. func (w *Wallet) savePretty() error { + data, err := json.MarshalIndent(w, "", " ") + if err != nil { + return err + } + + return w.writeRaw(data) +} + +func (w *Wallet) writeRaw(data []byte) error { if err := w.rewind(); err != nil { return err } - enc := json.NewEncoder(w.rw) - enc.SetIndent("", " ") - return enc.Encode(w) + + _, err := w.rw.Write(data) + return err } func (w *Wallet) rewind() error { From a429aa3e68eeae35a8813f8e4c42359077097860 Mon Sep 17 00:00:00 2001 From: Evgeniy Stratonikov Date: Thu, 29 Jul 2021 14:46:07 +0300 Subject: [PATCH 2/3] wallet: truncate file when writing If wallet size decreases, we need to remove trailing garbage if it exists. This can happen when removing account or reading pretty-printed wallet. It doesn't affect our CLI (we decode only file prefix), but it is nice to always have a valid JSON file. Signed-off-by: Evgeniy Stratonikov --- cli/wallet_test.go | 32 ++++++++++++++++++++++++++++++++ pkg/wallet/wallet.go | 11 ++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/cli/wallet_test.go b/cli/wallet_test.go index ddf57d05c..098660d42 100644 --- a/cli/wallet_test.go +++ b/cli/wallet_test.go @@ -19,6 +19,38 @@ import ( "github.com/stretchr/testify/require" ) +func TestWalletAccountRemove(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "neogo.test.walletinit") + require.NoError(t, err) + t.Cleanup(func() { + os.RemoveAll(tmpDir) + }) + + e := newExecutor(t, false) + + walletPath := path.Join(tmpDir, "wallet.json") + e.In.WriteString("acc1\r") + e.In.WriteString("pass\r") + e.In.WriteString("pass\r") + e.Run(t, "neo-go", "wallet", "init", "--wallet", walletPath, "--account") + + e.In.WriteString("acc2\r") + e.In.WriteString("pass\r") + e.In.WriteString("pass\r") + e.Run(t, "neo-go", "wallet", "create", "--wallet", walletPath) + + w, err := wallet.NewWalletFromFile(walletPath) + require.NoError(t, err) + + addr := w.Accounts[0].Address + e.Run(t, "neo-go", "wallet", "remove", "--wallet", walletPath, + "--address", addr, "--force") + + rawWallet, err := ioutil.ReadFile(walletPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(rawWallet, new(wallet.Wallet))) +} + func TestWalletInit(t *testing.T) { tmpDir, err := ioutil.TempDir("", "neogo.test.walletinit") require.NoError(t, err) diff --git a/pkg/wallet/wallet.go b/pkg/wallet/wallet.go index a4694bf4a..746466e22 100644 --- a/pkg/wallet/wallet.go +++ b/pkg/wallet/wallet.go @@ -167,7 +167,16 @@ func (w *Wallet) writeRaw(data []byte) error { } _, err := w.rw.Write(data) - return err + if err != nil { + return err + } + + if f, ok := w.rw.(*os.File); ok { + if err := f.Truncate(int64(len(data))); err != nil { + return err + } + } + return nil } func (w *Wallet) rewind() error { From 283173bb9de0c3da7bf61cef23298d96d2b051bc Mon Sep 17 00:00:00 2001 From: Evgeniy Stratonikov Date: Thu, 29 Jul 2021 14:59:23 +0300 Subject: [PATCH 3/3] wallet: use named constants in `Seek` Signed-off-by: Evgeniy Stratonikov --- pkg/wallet/wallet.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/wallet/wallet.go b/pkg/wallet/wallet.go index 746466e22..ca32a4790 100644 --- a/pkg/wallet/wallet.go +++ b/pkg/wallet/wallet.go @@ -181,7 +181,7 @@ func (w *Wallet) writeRaw(data []byte) error { func (w *Wallet) rewind() error { if s, ok := w.rw.(io.Seeker); ok { - if _, err := s.Seek(0, 0); err != nil { + if _, err := s.Seek(0, io.SeekStart); err != nil { return err } }