diff --git a/cli/flags/address.go b/cli/flags/address.go new file mode 100644 index 000000000..7468960d1 --- /dev/null +++ b/cli/flags/address.go @@ -0,0 +1,78 @@ +package flags + +import ( + "flag" + "fmt" + "strings" + + "github.com/nspcc-dev/neo-go/pkg/encoding/address" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/urfave/cli" +) + +// Address is a wrapper for Uint160 with flag.Value methods. +type Address util.Uint160 + +// AddressFlag is a flag with type string +type AddressFlag struct { + Name string + Usage string + Value Address +} + +var ( + _ flag.Value = (*Address)(nil) + _ cli.Flag = AddressFlag{} +) + +// String implements fmt.Stringer interface. +func (a Address) String() string { + return address.Uint160ToString(util.Uint160(a)) +} + +// Set implements flag.Value interface. +func (a *Address) Set(s string) error { + addr, err := address.StringToUint160(s) + if err != nil { + return cli.NewExitError(err, 1) + } + *a = Address(addr) + return nil +} + +// Uint160 casts address to Uint160. +func (a *Address) Uint160() (u util.Uint160) { + copy(u[:], a[:]) + return +} + +// String returns a readable representation of this value +// (for usage defaults) +func (f AddressFlag) String() string { + var names []string + eachName(f.Name, func(name string) { + names = append(names, getNameHelp(name)) + }) + + return strings.Join(names, ", ") + "\t" + f.Usage +} + +func getNameHelp(name string) string { + if len(name) == 1 { + return fmt.Sprintf("-%s value", name) + } + return fmt.Sprintf("--%s value", name) +} + +// GetName returns the name of the flag +func (f AddressFlag) GetName() string { + return f.Name +} + +// Apply populates the flag given the flag set and environment +// Ignores errors +func (f AddressFlag) Apply(set *flag.FlagSet) { + eachName(f.Name, func(name string) { + set.Var(&f.Value, name, f.Usage) + }) +} diff --git a/cli/flags/util.go b/cli/flags/util.go new file mode 100644 index 000000000..6b215b5e2 --- /dev/null +++ b/cli/flags/util.go @@ -0,0 +1,11 @@ +package flags + +import "strings" + +func eachName(longName string, fn func(string)) { + parts := strings.Split(longName, ",") + for _, name := range parts { + name = strings.Trim(name, " ") + fn(name) + } +} diff --git a/cli/wallet/wallet.go b/cli/wallet/wallet.go index d816d2afa..483eb6131 100644 --- a/cli/wallet/wallet.go +++ b/cli/wallet/wallet.go @@ -11,6 +11,7 @@ import ( "strings" "syscall" + "github.com/nspcc-dev/neo-go/cli/flags" "github.com/nspcc-dev/neo-go/pkg/core" "github.com/nspcc-dev/neo-go/pkg/core/transaction" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" @@ -74,7 +75,7 @@ func NewCommands() []cli.Command { walletPathFlag, rpcFlag, timeoutFlag, - cli.StringFlag{ + flags.AddressFlag{ Name: "address, a", Usage: "Address to claim GAS for", }, @@ -162,11 +163,11 @@ func NewCommands() []cli.Command { rpcFlag, timeoutFlag, outFlag, - cli.StringFlag{ + flags.AddressFlag{ Name: "from", Usage: "Address to send an asset from", }, - cli.StringFlag{ + flags.AddressFlag{ Name: "to", Usage: "Address to send an asset to", }, @@ -196,15 +197,11 @@ func claimGas(ctx *cli.Context) error { } defer wall.Close() - addr := ctx.String("address") - scriptHash, err := address.StringToUint160(addr) - if err != nil { - return cli.NewExitError(err, 1) - } - + addrFlag := ctx.Generic("address").(*flags.Address) + scriptHash := addrFlag.Uint160() acc := wall.GetAccount(scriptHash) if acc == nil { - return cli.NewExitError(fmt.Errorf("wallet contains no account for '%s'", addr), 1) + return cli.NewExitError(fmt.Errorf("wallet contains no account for '%s'", addrFlag), 1) } pass, err := readPassword("Enter password > ") @@ -221,7 +218,7 @@ func claimGas(ctx *cli.Context) error { if err != nil { return cli.NewExitError(err, 1) } - info, err := c.GetClaimable(addr) + info, err := c.GetClaimable(scriptHash.String()) if err != nil { return cli.NewExitError(err, 1) } else if info.Unclaimed == 0 || len(info.Spents) == 0 { @@ -397,14 +394,11 @@ func transferAsset(ctx *cli.Context) error { } defer wall.Close() - from := ctx.String("from") - addr, err := address.StringToUint160(from) - if err != nil { - return cli.NewExitError("invalid address", 1) - } - acc := wall.GetAccount(addr) + fromFlag := ctx.Generic("from").(*flags.Address) + from := fromFlag.Uint160() + acc := wall.GetAccount(from) if acc == nil { - return cli.NewExitError(fmt.Errorf("wallet contains no account for '%s'", addr), 1) + return cli.NewExitError(fmt.Errorf("wallet contains no account for '%s'", from), 1) } asset, err := getAssetID(ctx.String("asset")) @@ -433,14 +427,11 @@ func transferAsset(ctx *cli.Context) error { } tx := transaction.NewContractTX() - if err := request.AddInputsAndUnspentsToTx(tx, from, asset, amount, c); err != nil { + if err := request.AddInputsAndUnspentsToTx(tx, fromFlag.String(), asset, amount, c); err != nil { return cli.NewExitError(err, 1) } - toAddr, err := address.StringToUint160(ctx.String("to")) - if err != nil { - return cli.NewExitError(err, 1) - } + toAddr := ctx.Generic("to").(*flags.Address).Uint160() tx.AddOutput(&transaction.Output{ AssetID: asset, Amount: amount,