diff --git a/cli/wallet_test.go b/cli/wallet_test.go index ddf57d05c..098660d42 100644 --- a/cli/wallet_test.go +++ b/cli/wallet_test.go @@ -19,6 +19,38 @@ import ( "github.com/stretchr/testify/require" ) +func TestWalletAccountRemove(t *testing.T) { + tmpDir, err := ioutil.TempDir("", "neogo.test.walletinit") + require.NoError(t, err) + t.Cleanup(func() { + os.RemoveAll(tmpDir) + }) + + e := newExecutor(t, false) + + walletPath := path.Join(tmpDir, "wallet.json") + e.In.WriteString("acc1\r") + e.In.WriteString("pass\r") + e.In.WriteString("pass\r") + e.Run(t, "neo-go", "wallet", "init", "--wallet", walletPath, "--account") + + e.In.WriteString("acc2\r") + e.In.WriteString("pass\r") + e.In.WriteString("pass\r") + e.Run(t, "neo-go", "wallet", "create", "--wallet", walletPath) + + w, err := wallet.NewWalletFromFile(walletPath) + require.NoError(t, err) + + addr := w.Accounts[0].Address + e.Run(t, "neo-go", "wallet", "remove", "--wallet", walletPath, + "--address", addr, "--force") + + rawWallet, err := ioutil.ReadFile(walletPath) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(rawWallet, new(wallet.Wallet))) +} + func TestWalletInit(t *testing.T) { tmpDir, err := ioutil.TempDir("", "neogo.test.walletinit") require.NoError(t, err) diff --git a/pkg/wallet/wallet.go b/pkg/wallet/wallet.go index c42e65ef1..ca32a4790 100644 --- a/pkg/wallet/wallet.go +++ b/pkg/wallet/wallet.go @@ -143,25 +143,45 @@ func (w *Wallet) Path() string { // that is responsible for saving the data. This can // be a buffer, file, etc.. func (w *Wallet) Save() error { - if err := w.rewind(); err != nil { + data, err := json.Marshal(w) + if err != nil { return err } - return json.NewEncoder(w.rw).Encode(w) + + return w.writeRaw(data) } // savePretty saves wallet in a beautiful JSON. func (w *Wallet) savePretty() error { + data, err := json.MarshalIndent(w, "", " ") + if err != nil { + return err + } + + return w.writeRaw(data) +} + +func (w *Wallet) writeRaw(data []byte) error { if err := w.rewind(); err != nil { return err } - enc := json.NewEncoder(w.rw) - enc.SetIndent("", " ") - return enc.Encode(w) + + _, err := w.rw.Write(data) + if err != nil { + return err + } + + if f, ok := w.rw.(*os.File); ok { + if err := f.Truncate(int64(len(data))); err != nil { + return err + } + } + return nil } func (w *Wallet) rewind() error { if s, ok := w.rw.(io.Seeker); ok { - if _, err := s.Seek(0, 0); err != nil { + if _, err := s.Seek(0, io.SeekStart); err != nil { return err } }