cli: implement AddressFlag

This makes code less verbose and performs all parsing
before invoking main function.
This commit is contained in:
Evgenii Stratonikov 2020-03-03 23:09:47 +03:00
parent f8eee778f4
commit 9be4312d8d
3 changed files with 103 additions and 23 deletions

78
cli/flags/address.go Normal file
View file

@ -0,0 +1,78 @@
package flags
import (
"flag"
"fmt"
"strings"
"github.com/nspcc-dev/neo-go/pkg/encoding/address"
"github.com/nspcc-dev/neo-go/pkg/util"
"github.com/urfave/cli"
)
// Address is a wrapper for Uint160 with flag.Value methods.
type Address util.Uint160
// AddressFlag is a flag with type string
type AddressFlag struct {
Name string
Usage string
Value Address
}
var (
_ flag.Value = (*Address)(nil)
_ cli.Flag = AddressFlag{}
)
// String implements fmt.Stringer interface.
func (a Address) String() string {
return address.Uint160ToString(util.Uint160(a))
}
// Set implements flag.Value interface.
func (a *Address) Set(s string) error {
addr, err := address.StringToUint160(s)
if err != nil {
return cli.NewExitError(err, 1)
}
*a = Address(addr)
return nil
}
// Uint160 casts address to Uint160.
func (a *Address) Uint160() (u util.Uint160) {
copy(u[:], a[:])
return
}
// String returns a readable representation of this value
// (for usage defaults)
func (f AddressFlag) String() string {
var names []string
eachName(f.Name, func(name string) {
names = append(names, getNameHelp(name))
})
return strings.Join(names, ", ") + "\t" + f.Usage
}
func getNameHelp(name string) string {
if len(name) == 1 {
return fmt.Sprintf("-%s value", name)
}
return fmt.Sprintf("--%s value", name)
}
// GetName returns the name of the flag
func (f AddressFlag) GetName() string {
return f.Name
}
// Apply populates the flag given the flag set and environment
// Ignores errors
func (f AddressFlag) Apply(set *flag.FlagSet) {
eachName(f.Name, func(name string) {
set.Var(&f.Value, name, f.Usage)
})
}

11
cli/flags/util.go Normal file
View file

@ -0,0 +1,11 @@
package flags
import "strings"
func eachName(longName string, fn func(string)) {
parts := strings.Split(longName, ",")
for _, name := range parts {
name = strings.Trim(name, " ")
fn(name)
}
}

View file

@ -11,6 +11,7 @@ import (
"strings"
"syscall"
"github.com/nspcc-dev/neo-go/cli/flags"
"github.com/nspcc-dev/neo-go/pkg/core"
"github.com/nspcc-dev/neo-go/pkg/core/transaction"
"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
@ -74,7 +75,7 @@ func NewCommands() []cli.Command {
walletPathFlag,
rpcFlag,
timeoutFlag,
cli.StringFlag{
flags.AddressFlag{
Name: "address, a",
Usage: "Address to claim GAS for",
},
@ -162,11 +163,11 @@ func NewCommands() []cli.Command {
rpcFlag,
timeoutFlag,
outFlag,
cli.StringFlag{
flags.AddressFlag{
Name: "from",
Usage: "Address to send an asset from",
},
cli.StringFlag{
flags.AddressFlag{
Name: "to",
Usage: "Address to send an asset to",
},
@ -196,15 +197,11 @@ func claimGas(ctx *cli.Context) error {
}
defer wall.Close()
addr := ctx.String("address")
scriptHash, err := address.StringToUint160(addr)
if err != nil {
return cli.NewExitError(err, 1)
}
addrFlag := ctx.Generic("address").(*flags.Address)
scriptHash := addrFlag.Uint160()
acc := wall.GetAccount(scriptHash)
if acc == nil {
return cli.NewExitError(fmt.Errorf("wallet contains no account for '%s'", addr), 1)
return cli.NewExitError(fmt.Errorf("wallet contains no account for '%s'", addrFlag), 1)
}
pass, err := readPassword("Enter password > ")
@ -221,7 +218,7 @@ func claimGas(ctx *cli.Context) error {
if err != nil {
return cli.NewExitError(err, 1)
}
info, err := c.GetClaimable(addr)
info, err := c.GetClaimable(scriptHash.String())
if err != nil {
return cli.NewExitError(err, 1)
} else if info.Unclaimed == 0 || len(info.Spents) == 0 {
@ -397,14 +394,11 @@ func transferAsset(ctx *cli.Context) error {
}
defer wall.Close()
from := ctx.String("from")
addr, err := address.StringToUint160(from)
if err != nil {
return cli.NewExitError("invalid address", 1)
}
acc := wall.GetAccount(addr)
fromFlag := ctx.Generic("from").(*flags.Address)
from := fromFlag.Uint160()
acc := wall.GetAccount(from)
if acc == nil {
return cli.NewExitError(fmt.Errorf("wallet contains no account for '%s'", addr), 1)
return cli.NewExitError(fmt.Errorf("wallet contains no account for '%s'", from), 1)
}
asset, err := getAssetID(ctx.String("asset"))
@ -433,14 +427,11 @@ func transferAsset(ctx *cli.Context) error {
}
tx := transaction.NewContractTX()
if err := request.AddInputsAndUnspentsToTx(tx, from, asset, amount, c); err != nil {
if err := request.AddInputsAndUnspentsToTx(tx, fromFlag.String(), asset, amount, c); err != nil {
return cli.NewExitError(err, 1)
}
toAddr, err := address.StringToUint160(ctx.String("to"))
if err != nil {
return cli.NewExitError(err, 1)
}
toAddr := ctx.Generic("to").(*flags.Address).Uint160()
tx.AddOutput(&transaction.Output{
AssetID: asset,
Amount: amount,