diff --git a/cli/wallet/wallet.go b/cli/wallet/wallet.go index bc89e05b6..50e852a70 100644 --- a/cli/wallet/wallet.go +++ b/cli/wallet/wallet.go @@ -519,7 +519,7 @@ loop: } func importMultisig(ctx *cli.Context) error { - wall, _, err := openWallet(ctx, true) + wall, pass, err := openWallet(ctx, true) if err != nil { return cli.NewExitError(err, 1) } @@ -540,7 +540,12 @@ func importMultisig(ctx *cli.Context) error { } } - acc, err := newAccountFromWIF(ctx.App.Writer, ctx.String("wif"), wall.Scrypt) + var label *string + if ctx.IsSet("name") { + l := ctx.String("name") + label = &l + } + acc, err := newAccountFromWIF(ctx.App.Writer, ctx.String("wif"), wall.Scrypt, label, pass) if err != nil { return cli.NewExitError(err, 1) } @@ -549,9 +554,6 @@ func importMultisig(ctx *cli.Context) error { return cli.NewExitError(err, 1) } - if acc.Label == "" { - acc.Label = ctx.String("name") - } if err := addAccountAndSave(wall, acc); err != nil { return cli.NewExitError(err, 1) } @@ -563,7 +565,7 @@ func importDeployed(ctx *cli.Context) error { if err := cmdargs.EnsureNone(ctx); err != nil { return err } - wall, _, err := openWallet(ctx, true) + wall, pass, err := openWallet(ctx, true) if err != nil { return cli.NewExitError(err, 1) } @@ -574,7 +576,12 @@ func importDeployed(ctx *cli.Context) error { return cli.NewExitError("contract hash was not provided", 1) } - acc, err := newAccountFromWIF(ctx.App.Writer, ctx.String("wif"), wall.Scrypt) + var label *string + if ctx.IsSet("name") { + l := ctx.String("name") + label = &l + } + acc, err := newAccountFromWIF(ctx.App.Writer, ctx.String("wif"), wall.Scrypt, label, pass) if err != nil { return cli.NewExitError(err, 1) } @@ -606,9 +613,6 @@ func importDeployed(ctx *cli.Context) error { } acc.Contract.Deployed = true - if acc.Label == "" { - acc.Label = ctx.String("name") - } if err := addAccountAndSave(wall, acc); err != nil { return cli.NewExitError(err, 1) } @@ -620,13 +624,19 @@ func importWallet(ctx *cli.Context) error { if err := cmdargs.EnsureNone(ctx); err != nil { return err } - wall, _, err := openWallet(ctx, true) + wall, pass, err := openWallet(ctx, true) if err != nil { return cli.NewExitError(err, 1) } defer wall.Close() - acc, err := newAccountFromWIF(ctx.App.Writer, ctx.String("wif"), wall.Scrypt) + var label *string + if ctx.IsSet("name") { + l := ctx.String("name") + label = &l + } + + acc, err := newAccountFromWIF(ctx.App.Writer, ctx.String("wif"), wall.Scrypt, label, pass) if err != nil { return cli.NewExitError(err, 1) } @@ -639,9 +649,6 @@ func importWallet(ctx *cli.Context) error { acc.Contract.Script = ctr } - if acc.Label == "" { - acc.Label = ctx.String("name") - } if err := addAccountAndSave(wall, acc); err != nil { return cli.NewExitError(err, 1) } @@ -843,7 +850,7 @@ func createWallet(ctx *cli.Context) error { } func readAccountInfo() (string, string, error) { - name, err := input.ReadLine("Enter the name of the account > ") + name, err := readAccountName() if err != nil { return "", "", err } @@ -854,6 +861,10 @@ func readAccountInfo() (string, string, error) { return name, phrase, nil } +func readAccountName() (string, error) { + return input.ReadLine("Enter the name of the account > ") +} + func readNewPassword() (string, error) { phrase, err := input.ReadPassword(EnterNewPasswordPrompt) if err != nil { @@ -966,16 +977,33 @@ func ReadWalletConfig(configPath string) (*config.Wallet, error) { return cfg, nil } -func newAccountFromWIF(w io.Writer, wif string, scrypt keys.ScryptParams) (*wallet.Account, error) { +func newAccountFromWIF(w io.Writer, wif string, scrypt keys.ScryptParams, label *string, pass *string) (*wallet.Account, error) { + var ( + phrase, name string + err error + ) + if pass != nil { + phrase = *pass + } + if label != nil { + name = *label + } // note: NEP2 strings always have length of 58 even though // base58 strings can have different lengths even if slice lengths are equal if len(wif) == 58 { - pass, err := input.ReadPassword(EnterPasswordPrompt) - if err != nil { - return nil, fmt.Errorf("Error reading password: %w", err) + if pass == nil { + phrase, err = input.ReadPassword(EnterPasswordPrompt) + if err != nil { + return nil, fmt.Errorf("error reading password: %w", err) + } } - return wallet.NewAccountFromEncryptedWIF(wif, pass, scrypt) + acc, err := wallet.NewAccountFromEncryptedWIF(wif, phrase, scrypt) + if err != nil { + return nil, err + } + acc.Label = name + return acc, nil } acc, err := wallet.NewAccountFromWIF(wif) @@ -984,13 +1012,21 @@ func newAccountFromWIF(w io.Writer, wif string, scrypt keys.ScryptParams) (*wall } fmt.Fprintln(w, "Provided WIF was unencrypted. Wallet can contain only encrypted keys.") - name, pass, err := readAccountInfo() - if err != nil { - return nil, err + if label == nil { + name, err = readAccountName() + if err != nil { + return nil, fmt.Errorf("failed to read account label: %w", err) + } + } + if pass == nil { + phrase, err = readNewPassword() + if err != nil { + return nil, fmt.Errorf("failed to read new password: %w", err) + } } acc.Label = name - if err := acc.Encrypt(pass, scrypt); err != nil { + if err := acc.Encrypt(phrase, scrypt); err != nil { return nil, err } diff --git a/cli/wallet/wallet_test.go b/cli/wallet/wallet_test.go index 27bb4ec98..36b9c5204 100644 --- a/cli/wallet/wallet_test.go +++ b/cli/wallet/wallet_test.go @@ -318,7 +318,7 @@ func TestWalletInit(t *testing.T) { configPath := filepath.Join(tmp, "config.yaml") cfg := config.Wallet{ Path: walletPath, - Password: "pass", // This pass won't be taken into account. + Password: "qwerty", } res, err := yaml.Marshal(cfg) require.NoError(t, err) @@ -326,12 +326,26 @@ func TestWalletInit(t *testing.T) { priv, err = keys.NewPrivateKey() require.NoError(t, err) e.In.WriteString("test_account_4\r") - e.In.WriteString("qwerty\r") - e.In.WriteString("qwerty\r") e.Run(t, "neo-go", "wallet", "import", "--wallet-config", configPath, "--wif", priv.WIF(), "--contract", "0a0b0c0d") check(t, "test_account_4", "qwerty") }) + t.Run("from wallet config with account name argument", func(t *testing.T) { + tmp := t.TempDir() + configPath := filepath.Join(tmp, "config.yaml") + cfg := config.Wallet{ + Path: walletPath, + Password: "qwerty", + } + res, err := yaml.Marshal(cfg) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, res, 0666)) + priv, err = keys.NewPrivateKey() + require.NoError(t, err) + e.Run(t, "neo-go", "wallet", "import", + "--wallet-config", configPath, "--wif", priv.WIF(), "--contract", "0a0b0c0d", "--name", "test_account_5") + check(t, "test_account_5", "qwerty") + }) }) }) t.Run("EncryptedWIF", func(t *testing.T) { @@ -586,21 +600,25 @@ func TestWalletImportDeployed(t *testing.T) { e.In.WriteString("acc\rpass\rpass\r") e.Run(t, "neo-go", "wallet", "import-deployed", "--rpc-endpoint", "http://"+e.RPC.Addresses()[0], - "--wallet", walletPath, "--wif", priv.WIF(), "--name", "my_acc", + "--wallet", walletPath, "--wif", priv.WIF(), "--contract", h.StringLE()) - w, err := wallet.NewWalletFromFile(walletPath) - require.NoError(t, err) - require.Equal(t, 1, len(w.Accounts)) - contractAddr := w.Accounts[0].Address - require.Equal(t, address.Uint160ToString(h), contractAddr) - require.True(t, w.Accounts[0].Contract.Deployed) + contractAddr := address.Uint160ToString(h) + checkDeployed := func(t *testing.T) { + w, err := wallet.NewWalletFromFile(walletPath) + require.NoError(t, err) + require.Equal(t, 1, len(w.Accounts)) + actualAddr := w.Accounts[0].Address + require.Equal(t, contractAddr, actualAddr) + require.True(t, w.Accounts[0].Contract.Deployed) + } + checkDeployed(t) t.Run("re-importing", func(t *testing.T) { e.In.WriteString("acc\rpass\rpass\r") e.RunWithError(t, "neo-go", "wallet", "import-deployed", "--rpc-endpoint", "http://"+e.RPC.Addresses()[0], - "--wallet", walletPath, "--wif", priv.WIF(), "--name", "my_acc", + "--wallet", walletPath, "--wif", priv.WIF(), "--contract", h.StringLE()) }) @@ -630,6 +648,35 @@ func TestWalletImportDeployed(t *testing.T) { b, _ = e.Chain.GetGoverningTokenBalance(privTo.GetScriptHash()) require.Equal(t, big.NewInt(1), b) }) + + t.Run("import with name argument", func(t *testing.T) { + e.Run(t, "neo-go", "wallet", "remove", + "--wallet", walletPath, "--address", address.Uint160ToString(h), "--force") + e.In.WriteString("pass\rpass\r") + e.Run(t, "neo-go", "wallet", "import-deployed", + "--rpc-endpoint", "http://"+e.RPC.Addresses()[0], + "--wallet", walletPath, "--wif", priv.WIF(), + "--contract", h.StringLE(), "--name", "acc") + checkDeployed(t) + }) + + t.Run("import with name argument and wallet config", func(t *testing.T) { + e.Run(t, "neo-go", "wallet", "remove", + "--wallet", walletPath, "--address", address.Uint160ToString(h), "--force") + configPath := filepath.Join(t.TempDir(), "wallet-config.yaml") + cfg := &config.Wallet{ + Path: walletPath, + Password: "pass", + } + bytes, err := yaml.Marshal(cfg) + require.NoError(t, err) + require.NoError(t, os.WriteFile(configPath, bytes, os.ModePerm)) + e.Run(t, "neo-go", "wallet", "import-deployed", + "--rpc-endpoint", "http://"+e.RPC.Addresses()[0], + "--wallet-config", configPath, "--wif", priv.WIF(), + "--contract", h.StringLE(), "--name", "acc") + checkDeployed(t) + }) } func TestStripKeys(t *testing.T) {