fixedn: allow to parse big decimals

This commit is contained in:
Evgenii Stratonikov 2020-11-30 12:06:36 +03:00
parent e4c3339c91
commit 56b23b718d
5 changed files with 137 additions and 47 deletions

View file

@ -391,14 +391,14 @@ func multiTransferNEP17(ctx *cli.Context) error {
if err != nil {
return cli.NewExitError(fmt.Errorf("invalid address: '%s'", ss[1]), 1)
}
amount, err := fixedn.FixedNFromString(ss[2], int(token.Decimals))
amount, err := fixedn.FromString(ss[2], int(token.Decimals))
if err != nil {
return cli.NewExitError(fmt.Errorf("invalid amount: %w", err), 1)
}
recipients = append(recipients, client.TransferTarget{
Token: token.Hash,
Address: addr,
Amount: amount,
Amount: amount.Int64(),
})
}
@ -438,7 +438,7 @@ func transferNEP17(ctx *cli.Context) error {
}
}
amount, err := fixedn.FixedNFromString(ctx.String("amount"), int(token.Decimals))
amount, err := fixedn.FromString(ctx.String("amount"), int(token.Decimals))
if err != nil {
return cli.NewExitError(fmt.Errorf("invalid amount: %w", err), 1)
}
@ -446,7 +446,7 @@ func transferNEP17(ctx *cli.Context) error {
return signAndSendTransfer(ctx, c, acc, []client.TransferTarget{{
Token: token.Hash,
Address: to,
Amount: amount,
Amount: amount.Int64(),
}})
}

View file

@ -0,0 +1,80 @@
package fixedn
import (
"errors"
"fmt"
"math/big"
"strconv"
"strings"
)
const maxAllowedPrecision = 16
// ErrInvalidFormat is returned when decimal format is invalid.
var ErrInvalidFormat = errors.New("invalid decimal format")
var _pow10 []*big.Int
func init() {
var p = int64(1)
for i := 0; i <= maxAllowedPrecision; i++ {
_pow10 = append(_pow10, big.NewInt(p))
p *= 10
}
}
func pow10(n int) *big.Int {
last := len(_pow10) - 1
if n <= last {
return _pow10[n]
}
p := new(big.Int)
p.Mul(_pow10[last], _pow10[1])
for i := last + 1; i < n; i++ {
p.Mul(p, _pow10[1])
}
return p
}
// ToString converts big decimal with specified precision to string.
func ToString(bi *big.Int, precision int) string {
var dp, fp big.Int
dp.QuoRem(bi, pow10(precision), &fp)
var s = dp.String()
if fp.Sign() == 0 {
return s
}
frac := fp.Uint64()
trimmed := 0
for ; frac%10 == 0; frac /= 10 {
trimmed++
}
return s + "." + fmt.Sprintf("%0"+strconv.FormatUint(uint64(precision-trimmed), 10)+"d", frac)
}
// FromString converts string to a big decimal with specified precision.
func FromString(s string, precision int) (*big.Int, error) {
parts := strings.SplitN(s, ".", 2)
bi, ok := new(big.Int).SetString(parts[0], 10)
if !ok {
return nil, ErrInvalidFormat
}
bi.Mul(bi, pow10(precision))
if len(parts) == 1 {
return bi, nil
}
if len(parts[1]) > precision {
return nil, ErrInvalidFormat
}
fp, ok := new(big.Int).SetString(parts[1], 10)
if !ok {
return nil, ErrInvalidFormat
}
fp.Mul(fp, pow10(precision-len(parts[1])))
if bi.Sign() == -1 {
return bi.Sub(bi, fp), nil
}
return bi.Add(bi, fp), nil
}

View file

@ -0,0 +1,51 @@
package fixedn
import (
"math/big"
"testing"
"github.com/stretchr/testify/require"
)
func TestDecimalFromStringGood(t *testing.T) {
var testCases = []struct {
bi *big.Int
prec int
s string
}{
{big.NewInt(123), 2, "1.23"},
{big.NewInt(12300), 2, "123"},
{big.NewInt(1234500000), 8, "12.345"},
{big.NewInt(-12345), 3, "-12.345"},
{big.NewInt(35), 8, "0.00000035"},
{big.NewInt(1230), 5, "0.0123"},
{big.NewInt(123456789), 20, "0.00000000000123456789"},
}
for _, tc := range testCases {
t.Run(tc.s, func(t *testing.T) {
s := ToString(tc.bi, tc.prec)
require.Equal(t, tc.s, s)
bi, err := FromString(s, tc.prec)
require.NoError(t, err)
require.Equal(t, tc.bi, bi)
})
}
}
func TestDecimalFromStringBad(t *testing.T) {
var errCases = []struct {
s string
prec int
}{
{"12A", 1},
{"12.345", 2},
{"12.3A", 2},
}
for _, tc := range errCases {
t.Run(tc.s, func(t *testing.T) {
_, err := FromString(tc.s, tc.prec)
require.Error(t, err)
})
}
}

View file

@ -3,7 +3,6 @@ package fixedn
import (
"encoding/json"
"errors"
"math"
"strconv"
"strings"
@ -72,36 +71,11 @@ func Fixed8FromFloat(val float64) Fixed8 {
// Fixed8FromString parses s which must be a fixed point number
// with precision up to 10^-8
func Fixed8FromString(s string) (Fixed8, error) {
num, err := FixedNFromString(s, precision)
num, err := FromString(s, precision)
if err != nil {
return 0, err
}
return Fixed8(num), err
}
// FixedNFromString parses s which must be a fixed point number
// with precision 10^-d.
func FixedNFromString(s string, precision int) (int64, error) {
parts := strings.SplitN(s, ".", 2)
d := int64(math.Pow10(precision))
ip, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return 0, errInvalidString
} else if len(parts) == 1 {
return ip * d, nil
}
fp, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil || fp >= d {
return 0, errInvalidString
}
for i := len(parts[1]); i < precision; i++ {
fp *= 10
}
if ip < 0 {
return ip*d - fp, nil
}
return ip*d + fp, nil
return Fixed8(num.Int64()), err
}
// UnmarshalJSON implements the json unmarshaller interface.

View file

@ -8,7 +8,6 @@ import (
"github.com/nspcc-dev/neo-go/internal/testserdes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
)
@ -85,20 +84,6 @@ func TestFixed8FromString(t *testing.T) {
assert.Error(t, err)
}
func TestFixedNFromString(t *testing.T) {
val := "123.456"
num, err := FixedNFromString(val, 3)
require.NoError(t, err)
require.EqualValues(t, 123456, num)
num, err = FixedNFromString(val, 4)
require.NoError(t, err)
require.EqualValues(t, 1234560, num)
_, err = FixedNFromString(val, 2)
require.Error(t, err)
}
func TestSatoshi(t *testing.T) {
satoshif8 := Satoshi()
assert.Equal(t, "0.00000001", satoshif8.String())