forked from TrueCloudLab/neoneo-go
fixedn: allow to parse big decimals
This commit is contained in:
parent
e4c3339c91
commit
56b23b718d
5 changed files with 137 additions and 47 deletions
|
@ -391,14 +391,14 @@ func multiTransferNEP17(ctx *cli.Context) error {
|
|||
if err != nil {
|
||||
return cli.NewExitError(fmt.Errorf("invalid address: '%s'", ss[1]), 1)
|
||||
}
|
||||
amount, err := fixedn.FixedNFromString(ss[2], int(token.Decimals))
|
||||
amount, err := fixedn.FromString(ss[2], int(token.Decimals))
|
||||
if err != nil {
|
||||
return cli.NewExitError(fmt.Errorf("invalid amount: %w", err), 1)
|
||||
}
|
||||
recipients = append(recipients, client.TransferTarget{
|
||||
Token: token.Hash,
|
||||
Address: addr,
|
||||
Amount: amount,
|
||||
Amount: amount.Int64(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -438,7 +438,7 @@ func transferNEP17(ctx *cli.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
amount, err := fixedn.FixedNFromString(ctx.String("amount"), int(token.Decimals))
|
||||
amount, err := fixedn.FromString(ctx.String("amount"), int(token.Decimals))
|
||||
if err != nil {
|
||||
return cli.NewExitError(fmt.Errorf("invalid amount: %w", err), 1)
|
||||
}
|
||||
|
@ -446,7 +446,7 @@ func transferNEP17(ctx *cli.Context) error {
|
|||
return signAndSendTransfer(ctx, c, acc, []client.TransferTarget{{
|
||||
Token: token.Hash,
|
||||
Address: to,
|
||||
Amount: amount,
|
||||
Amount: amount.Int64(),
|
||||
}})
|
||||
}
|
||||
|
||||
|
|
80
pkg/encoding/fixedn/decimal.go
Normal file
80
pkg/encoding/fixedn/decimal.go
Normal file
|
@ -0,0 +1,80 @@
|
|||
package fixedn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxAllowedPrecision = 16
|
||||
|
||||
// ErrInvalidFormat is returned when decimal format is invalid.
|
||||
var ErrInvalidFormat = errors.New("invalid decimal format")
|
||||
|
||||
var _pow10 []*big.Int
|
||||
|
||||
func init() {
|
||||
var p = int64(1)
|
||||
for i := 0; i <= maxAllowedPrecision; i++ {
|
||||
_pow10 = append(_pow10, big.NewInt(p))
|
||||
p *= 10
|
||||
}
|
||||
}
|
||||
|
||||
func pow10(n int) *big.Int {
|
||||
last := len(_pow10) - 1
|
||||
if n <= last {
|
||||
return _pow10[n]
|
||||
}
|
||||
p := new(big.Int)
|
||||
p.Mul(_pow10[last], _pow10[1])
|
||||
for i := last + 1; i < n; i++ {
|
||||
p.Mul(p, _pow10[1])
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ToString converts big decimal with specified precision to string.
|
||||
func ToString(bi *big.Int, precision int) string {
|
||||
var dp, fp big.Int
|
||||
dp.QuoRem(bi, pow10(precision), &fp)
|
||||
|
||||
var s = dp.String()
|
||||
if fp.Sign() == 0 {
|
||||
return s
|
||||
}
|
||||
frac := fp.Uint64()
|
||||
trimmed := 0
|
||||
for ; frac%10 == 0; frac /= 10 {
|
||||
trimmed++
|
||||
}
|
||||
return s + "." + fmt.Sprintf("%0"+strconv.FormatUint(uint64(precision-trimmed), 10)+"d", frac)
|
||||
}
|
||||
|
||||
// FromString converts string to a big decimal with specified precision.
|
||||
func FromString(s string, precision int) (*big.Int, error) {
|
||||
parts := strings.SplitN(s, ".", 2)
|
||||
bi, ok := new(big.Int).SetString(parts[0], 10)
|
||||
if !ok {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
bi.Mul(bi, pow10(precision))
|
||||
if len(parts) == 1 {
|
||||
return bi, nil
|
||||
}
|
||||
|
||||
if len(parts[1]) > precision {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
fp, ok := new(big.Int).SetString(parts[1], 10)
|
||||
if !ok {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
fp.Mul(fp, pow10(precision-len(parts[1])))
|
||||
if bi.Sign() == -1 {
|
||||
return bi.Sub(bi, fp), nil
|
||||
}
|
||||
return bi.Add(bi, fp), nil
|
||||
}
|
51
pkg/encoding/fixedn/decimal_test.go
Normal file
51
pkg/encoding/fixedn/decimal_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package fixedn
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDecimalFromStringGood(t *testing.T) {
|
||||
var testCases = []struct {
|
||||
bi *big.Int
|
||||
prec int
|
||||
s string
|
||||
}{
|
||||
{big.NewInt(123), 2, "1.23"},
|
||||
{big.NewInt(12300), 2, "123"},
|
||||
{big.NewInt(1234500000), 8, "12.345"},
|
||||
{big.NewInt(-12345), 3, "-12.345"},
|
||||
{big.NewInt(35), 8, "0.00000035"},
|
||||
{big.NewInt(1230), 5, "0.0123"},
|
||||
{big.NewInt(123456789), 20, "0.00000000000123456789"},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.s, func(t *testing.T) {
|
||||
s := ToString(tc.bi, tc.prec)
|
||||
require.Equal(t, tc.s, s)
|
||||
|
||||
bi, err := FromString(s, tc.prec)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tc.bi, bi)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecimalFromStringBad(t *testing.T) {
|
||||
var errCases = []struct {
|
||||
s string
|
||||
prec int
|
||||
}{
|
||||
{"12A", 1},
|
||||
{"12.345", 2},
|
||||
{"12.3A", 2},
|
||||
}
|
||||
for _, tc := range errCases {
|
||||
t.Run(tc.s, func(t *testing.T) {
|
||||
_, err := FromString(tc.s, tc.prec)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,7 +3,6 @@ package fixedn
|
|||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -72,36 +71,11 @@ func Fixed8FromFloat(val float64) Fixed8 {
|
|||
// Fixed8FromString parses s which must be a fixed point number
|
||||
// with precision up to 10^-8
|
||||
func Fixed8FromString(s string) (Fixed8, error) {
|
||||
num, err := FixedNFromString(s, precision)
|
||||
num, err := FromString(s, precision)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return Fixed8(num), err
|
||||
}
|
||||
|
||||
// FixedNFromString parses s which must be a fixed point number
|
||||
// with precision 10^-d.
|
||||
func FixedNFromString(s string, precision int) (int64, error) {
|
||||
parts := strings.SplitN(s, ".", 2)
|
||||
d := int64(math.Pow10(precision))
|
||||
ip, err := strconv.ParseInt(parts[0], 10, 64)
|
||||
if err != nil {
|
||||
return 0, errInvalidString
|
||||
} else if len(parts) == 1 {
|
||||
return ip * d, nil
|
||||
}
|
||||
|
||||
fp, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil || fp >= d {
|
||||
return 0, errInvalidString
|
||||
}
|
||||
for i := len(parts[1]); i < precision; i++ {
|
||||
fp *= 10
|
||||
}
|
||||
if ip < 0 {
|
||||
return ip*d - fp, nil
|
||||
}
|
||||
return ip*d + fp, nil
|
||||
return Fixed8(num.Int64()), err
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements the json unmarshaller interface.
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
|
||||
"github.com/nspcc-dev/neo-go/internal/testserdes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
|
@ -85,20 +84,6 @@ func TestFixed8FromString(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFixedNFromString(t *testing.T) {
|
||||
val := "123.456"
|
||||
num, err := FixedNFromString(val, 3)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 123456, num)
|
||||
|
||||
num, err = FixedNFromString(val, 4)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1234560, num)
|
||||
|
||||
_, err = FixedNFromString(val, 2)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSatoshi(t *testing.T) {
|
||||
satoshif8 := Satoshi()
|
||||
assert.Equal(t, "0.00000001", satoshif8.String())
|
||||
|
|
Loading…
Reference in a new issue