From 56b23b718d338f300a5410efd8a5f4d4332429fa Mon Sep 17 00:00:00 2001 From: Evgenii Stratonikov Date: Mon, 30 Nov 2020 12:06:36 +0300 Subject: [PATCH] fixedn: allow to parse big decimals --- cli/wallet/nep17.go | 8 +-- pkg/encoding/fixedn/decimal.go | 80 +++++++++++++++++++++++++++++ pkg/encoding/fixedn/decimal_test.go | 51 ++++++++++++++++++ pkg/encoding/fixedn/fixed8.go | 30 +---------- pkg/encoding/fixedn/fixed8_test.go | 15 ------ 5 files changed, 137 insertions(+), 47 deletions(-) create mode 100644 pkg/encoding/fixedn/decimal.go create mode 100644 pkg/encoding/fixedn/decimal_test.go diff --git a/cli/wallet/nep17.go b/cli/wallet/nep17.go index 8257d4019..dd9006bd8 100644 --- a/cli/wallet/nep17.go +++ b/cli/wallet/nep17.go @@ -391,14 +391,14 @@ func multiTransferNEP17(ctx *cli.Context) error { if err != nil { return cli.NewExitError(fmt.Errorf("invalid address: '%s'", ss[1]), 1) } - amount, err := fixedn.FixedNFromString(ss[2], int(token.Decimals)) + amount, err := fixedn.FromString(ss[2], int(token.Decimals)) if err != nil { return cli.NewExitError(fmt.Errorf("invalid amount: %w", err), 1) } recipients = append(recipients, client.TransferTarget{ Token: token.Hash, Address: addr, - Amount: amount, + Amount: amount.Int64(), }) } @@ -438,7 +438,7 @@ func transferNEP17(ctx *cli.Context) error { } } - amount, err := fixedn.FixedNFromString(ctx.String("amount"), int(token.Decimals)) + amount, err := fixedn.FromString(ctx.String("amount"), int(token.Decimals)) if err != nil { return cli.NewExitError(fmt.Errorf("invalid amount: %w", err), 1) } @@ -446,7 +446,7 @@ func transferNEP17(ctx *cli.Context) error { return signAndSendTransfer(ctx, c, acc, []client.TransferTarget{{ Token: token.Hash, Address: to, - Amount: amount, + Amount: amount.Int64(), }}) } diff --git a/pkg/encoding/fixedn/decimal.go b/pkg/encoding/fixedn/decimal.go new file mode 100644 index 000000000..f26976e47 --- /dev/null +++ b/pkg/encoding/fixedn/decimal.go @@ -0,0 +1,80 @@ +package fixedn + +import ( + "errors" + "fmt" + "math/big" + "strconv" + "strings" +) + +const maxAllowedPrecision = 16 + +// ErrInvalidFormat is returned when decimal format is invalid. +var ErrInvalidFormat = errors.New("invalid decimal format") + +var _pow10 []*big.Int + +func init() { + var p = int64(1) + for i := 0; i <= maxAllowedPrecision; i++ { + _pow10 = append(_pow10, big.NewInt(p)) + p *= 10 + } +} + +func pow10(n int) *big.Int { + last := len(_pow10) - 1 + if n <= last { + return _pow10[n] + } + p := new(big.Int) + p.Mul(_pow10[last], _pow10[1]) + for i := last + 1; i < n; i++ { + p.Mul(p, _pow10[1]) + } + return p +} + +// ToString converts big decimal with specified precision to string. +func ToString(bi *big.Int, precision int) string { + var dp, fp big.Int + dp.QuoRem(bi, pow10(precision), &fp) + + var s = dp.String() + if fp.Sign() == 0 { + return s + } + frac := fp.Uint64() + trimmed := 0 + for ; frac%10 == 0; frac /= 10 { + trimmed++ + } + return s + "." + fmt.Sprintf("%0"+strconv.FormatUint(uint64(precision-trimmed), 10)+"d", frac) +} + +// FromString converts string to a big decimal with specified precision. +func FromString(s string, precision int) (*big.Int, error) { + parts := strings.SplitN(s, ".", 2) + bi, ok := new(big.Int).SetString(parts[0], 10) + if !ok { + return nil, ErrInvalidFormat + } + bi.Mul(bi, pow10(precision)) + if len(parts) == 1 { + return bi, nil + } + + if len(parts[1]) > precision { + return nil, ErrInvalidFormat + } + fp, ok := new(big.Int).SetString(parts[1], 10) + if !ok { + return nil, ErrInvalidFormat + } + fp.Mul(fp, pow10(precision-len(parts[1]))) + if bi.Sign() == -1 { + return bi.Sub(bi, fp), nil + } + return bi.Add(bi, fp), nil +} diff --git a/pkg/encoding/fixedn/decimal_test.go b/pkg/encoding/fixedn/decimal_test.go new file mode 100644 index 000000000..be676e85f --- /dev/null +++ b/pkg/encoding/fixedn/decimal_test.go @@ -0,0 +1,51 @@ +package fixedn + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecimalFromStringGood(t *testing.T) { + var testCases = []struct { + bi *big.Int + prec int + s string + }{ + {big.NewInt(123), 2, "1.23"}, + {big.NewInt(12300), 2, "123"}, + {big.NewInt(1234500000), 8, "12.345"}, + {big.NewInt(-12345), 3, "-12.345"}, + {big.NewInt(35), 8, "0.00000035"}, + {big.NewInt(1230), 5, "0.0123"}, + {big.NewInt(123456789), 20, "0.00000000000123456789"}, + } + for _, tc := range testCases { + t.Run(tc.s, func(t *testing.T) { + s := ToString(tc.bi, tc.prec) + require.Equal(t, tc.s, s) + + bi, err := FromString(s, tc.prec) + require.NoError(t, err) + require.Equal(t, tc.bi, bi) + }) + } +} + +func TestDecimalFromStringBad(t *testing.T) { + var errCases = []struct { + s string + prec int + }{ + {"12A", 1}, + {"12.345", 2}, + {"12.3A", 2}, + } + for _, tc := range errCases { + t.Run(tc.s, func(t *testing.T) { + _, err := FromString(tc.s, tc.prec) + require.Error(t, err) + }) + } +} diff --git a/pkg/encoding/fixedn/fixed8.go b/pkg/encoding/fixedn/fixed8.go index 80103b9d5..014a6e888 100644 --- a/pkg/encoding/fixedn/fixed8.go +++ b/pkg/encoding/fixedn/fixed8.go @@ -3,7 +3,6 @@ package fixedn import ( "encoding/json" "errors" - "math" "strconv" "strings" @@ -72,36 +71,11 @@ func Fixed8FromFloat(val float64) Fixed8 { // Fixed8FromString parses s which must be a fixed point number // with precision up to 10^-8 func Fixed8FromString(s string) (Fixed8, error) { - num, err := FixedNFromString(s, precision) + num, err := FromString(s, precision) if err != nil { return 0, err } - return Fixed8(num), err -} - -// FixedNFromString parses s which must be a fixed point number -// with precision 10^-d. -func FixedNFromString(s string, precision int) (int64, error) { - parts := strings.SplitN(s, ".", 2) - d := int64(math.Pow10(precision)) - ip, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil { - return 0, errInvalidString - } else if len(parts) == 1 { - return ip * d, nil - } - - fp, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil || fp >= d { - return 0, errInvalidString - } - for i := len(parts[1]); i < precision; i++ { - fp *= 10 - } - if ip < 0 { - return ip*d - fp, nil - } - return ip*d + fp, nil + return Fixed8(num.Int64()), err } // UnmarshalJSON implements the json unmarshaller interface. diff --git a/pkg/encoding/fixedn/fixed8_test.go b/pkg/encoding/fixedn/fixed8_test.go index 5560e6a4d..4bbaf2af2 100644 --- a/pkg/encoding/fixedn/fixed8_test.go +++ b/pkg/encoding/fixedn/fixed8_test.go @@ -8,7 +8,6 @@ import ( "github.com/nspcc-dev/neo-go/internal/testserdes" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -85,20 +84,6 @@ func TestFixed8FromString(t *testing.T) { assert.Error(t, err) } -func TestFixedNFromString(t *testing.T) { - val := "123.456" - num, err := FixedNFromString(val, 3) - require.NoError(t, err) - require.EqualValues(t, 123456, num) - - num, err = FixedNFromString(val, 4) - require.NoError(t, err) - require.EqualValues(t, 1234560, num) - - _, err = FixedNFromString(val, 2) - require.Error(t, err) -} - func TestSatoshi(t *testing.T) { satoshif8 := Satoshi() assert.Equal(t, "0.00000001", satoshif8.String())