restic: Make JSON unmarshal for ID more efficient

This commit reduces several allocations in UnmarshalJSON() by decoding
the hex string directly in a single step.
This commit is contained in:
Alexander Neumann 2019-03-22 20:30:29 +01:00
parent 2022355800
commit 203d775190
2 changed files with 64 additions and 6 deletions

View file

@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
@ -101,13 +102,33 @@ func (id ID) MarshalJSON() ([]byte, error) {
// UnmarshalJSON parses the JSON-encoded data and stores the result in id. // UnmarshalJSON parses the JSON-encoded data and stores the result in id.
func (id *ID) UnmarshalJSON(b []byte) error { func (id *ID) UnmarshalJSON(b []byte) error {
var s string // check string length
err := json.Unmarshal(b, &s) if len(b) < 2 {
if err != nil { return fmt.Errorf("invalid ID: %q", b)
return errors.Wrap(err, "Unmarshal")
} }
_, err = hex.Decode(id[:], []byte(s)) if len(b)%2 != 0 {
return fmt.Errorf("invalid ID length: %q", b)
}
// check string delimiters
if b[0] != '"' && b[0] != '\'' {
return fmt.Errorf("invalid start of string: %q", b[0])
}
last := len(b) - 1
if b[0] != b[last] {
return fmt.Errorf("starting string delimiter (%q) does not match end (%q)", b[0], b[last])
}
// strip JSON string delimiters
b = b[1:last]
if len(b) != 2*len(id) {
return fmt.Errorf("invalid length for ID")
}
_, err := hex.Decode(id[:], b)
if err != nil { if err != nil {
return errors.Wrap(err, "hex.Decode") return errors.Wrap(err, "hex.Decode")
} }

View file

@ -51,10 +51,47 @@ func TestID(t *testing.T) {
var id3 ID var id3 ID
err = id3.UnmarshalJSON(buf) err = id3.UnmarshalJSON(buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatalf("error for %q: %v", buf, err)
} }
if !reflect.DeepEqual(id, id3) { if !reflect.DeepEqual(id, id3) {
t.Error("ids are not equal") t.Error("ids are not equal")
} }
} }
} }
func TestIDUnmarshal(t *testing.T) {
var tests = []struct {
s string
valid bool
}{
{`"`, false},
{`""`, false},
{`'`, false},
{`"`, false},
{`"c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4"`, false},
{`"c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f"`, false},
{`"c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2"`, true},
}
wantID, err := ParseID("c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2")
if err != nil {
t.Fatal(err)
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
id := &ID{}
err := id.UnmarshalJSON([]byte(test.s))
if test.valid && err != nil {
t.Fatal(err)
}
if !test.valid && err == nil {
t.Fatalf("want error for invalid value, got nil")
}
if test.valid && !id.Equal(wantID) {
t.Fatalf("wrong ID returned, want %s, got %s", wantID, id)
}
})
}
}