From ec0d0274fa936de70b407b3a50c4a32d24261570 Mon Sep 17 00:00:00 2001 From: Dmitrii Stepanov Date: Thu, 9 Mar 2023 11:33:21 +0300 Subject: [PATCH] [#3] signature: Add buffer pool Signed-off-by: Dmitrii Stepanov --- signature/sign.go | 5 +++- signature/verify.go | 49 ++++++++++++++------------------------- util/signature/buffer.go | 29 +++++++++++++++++++++++ util/signature/data.go | 20 +++++++--------- util/signature/options.go | 22 +++++++----------- 5 files changed, 67 insertions(+), 58 deletions(-) create mode 100644 util/signature/buffer.go diff --git a/signature/sign.go b/signature/sign.go index 5ed7870d..49343bd8 100644 --- a/signature/sign.go +++ b/signature/sign.go @@ -101,10 +101,13 @@ func signMessageParts(key *ecdsa.PrivateKey, body, meta, header stableMarshaler, func signServiceMessagePart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*refs.Signature)) error { var sig *refs.Signature + wrapper := StableMarshalerWrapper{ + SM: part, + } // sign part if err := signature.SignDataWithHandler( key, - &StableMarshalerWrapper{part}, + wrapper, func(s *refs.Signature) { sig = s }, diff --git a/signature/verify.go b/signature/verify.go index 08f91d9d..67e4a820 100644 --- a/signature/verify.go +++ b/signature/verify.go @@ -16,12 +16,6 @@ type signatureProvider interface { GetOriginSignature() *refs.Signature } -type buffers struct { - Body []byte - Meta []byte - Header []byte -} - // VerifyServiceMessage verifies service message. func VerifyServiceMessage(msg interface{}) error { switch v := msg.(type) { @@ -40,23 +34,14 @@ func verifyServiceRequest(v serviceRequest) error { meta := v.GetMetaHeader() verificationHeader := v.GetVerificationHeader() body := serviceMessageBody(v) - buffers := createBuffers(body.StableSize(), meta.StableSize(), verificationHeader.StableSize()) - return verifyServiceRequestRecursive(body, meta, verificationHeader, buffers) + return verifyServiceRequestRecursive(body, meta, verificationHeader) } -func createBuffers(bodySize, metaSize, headerSize int) *buffers { - return &buffers{ - Body: make([]byte, 0, bodySize), - Meta: make([]byte, 0, metaSize), - Header: make([]byte, 0, headerSize), - } -} - -func verifyServiceRequestRecursive(body stableMarshaler, meta *session.RequestMetaHeader, verify *session.RequestVerificationHeader, buffers *buffers) error { +func verifyServiceRequestRecursive(body stableMarshaler, meta *session.RequestMetaHeader, verify *session.RequestVerificationHeader) error { verificationHeaderOrigin := verify.GetOrigin() metaOrigin := meta.GetOrigin() - stop, err := verifyMessageParts(body, meta, verificationHeaderOrigin, verificationHeaderOrigin != nil, verify, buffers) + stop, err := verifyMessageParts(body, meta, verificationHeaderOrigin, verificationHeaderOrigin != nil, verify) if err != nil { return err } @@ -64,21 +49,21 @@ func verifyServiceRequestRecursive(body stableMarshaler, meta *session.RequestMe return nil } - return verifyServiceRequestRecursive(body, metaOrigin, verificationHeaderOrigin, buffers) + return verifyServiceRequestRecursive(body, metaOrigin, verificationHeaderOrigin) } -func verifyMessageParts(body, meta, originHeader stableMarshaler, hasOriginHeader bool, sigProvider signatureProvider, buffers *buffers) (stop bool, err error) { +func verifyMessageParts(body, meta, originHeader stableMarshaler, hasOriginHeader bool, sigProvider signatureProvider) (stop bool, err error) { eg := &errgroup.Group{} eg.Go(func() error { - if err := verifyServiceMessagePart(meta, sigProvider.GetMetaSignature, buffers.Meta); err != nil { + if err := verifyServiceMessagePart(meta, sigProvider.GetMetaSignature); err != nil { return fmt.Errorf("could not verify meta header: %w", err) } return nil }) eg.Go(func() error { - if err := verifyServiceMessagePart(originHeader, sigProvider.GetOriginSignature, buffers.Header); err != nil { + if err := verifyServiceMessagePart(originHeader, sigProvider.GetOriginSignature); err != nil { return fmt.Errorf("could not verify origin of verification header: %w", err) } return nil @@ -86,7 +71,7 @@ func verifyMessageParts(body, meta, originHeader stableMarshaler, hasOriginHeade if !hasOriginHeader { eg.Go(func() error { - if err := verifyServiceMessagePart(body, sigProvider.GetBodySignature, buffers.Body); err != nil { + if err := verifyServiceMessagePart(body, sigProvider.GetBodySignature); err != nil { return fmt.Errorf("could not verify body: %w", err) } return nil @@ -112,15 +97,14 @@ func verifyServiceResponse(v serviceResponse) error { meta := v.GetMetaHeader() verificationHeader := v.GetVerificationHeader() body := serviceMessageBody(v) - buffers := createBuffers(body.StableSize(), meta.StableSize(), verificationHeader.StableSize()) - return verifyServiceResponseRecursive(body, meta, verificationHeader, buffers) + return verifyServiceResponseRecursive(body, meta, verificationHeader) } -func verifyServiceResponseRecursive(body stableMarshaler, meta *session.ResponseMetaHeader, verify *session.ResponseVerificationHeader, buffers *buffers) error { +func verifyServiceResponseRecursive(body stableMarshaler, meta *session.ResponseMetaHeader, verify *session.ResponseVerificationHeader) error { verificationHeaderOrigin := verify.GetOrigin() metaOrigin := meta.GetOrigin() - stop, err := verifyMessageParts(body, meta, verificationHeaderOrigin, verificationHeaderOrigin != nil, verify, buffers) + stop, err := verifyMessageParts(body, meta, verificationHeaderOrigin, verificationHeaderOrigin != nil, verify) if err != nil { return err } @@ -128,13 +112,16 @@ func verifyServiceResponseRecursive(body stableMarshaler, meta *session.Response return nil } - return verifyServiceResponseRecursive(body, metaOrigin, verificationHeaderOrigin, buffers) + return verifyServiceResponseRecursive(body, metaOrigin, verificationHeaderOrigin) } -func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *refs.Signature, buf []byte) error { +func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *refs.Signature) error { + wrapper := StableMarshalerWrapper{ + SM: part, + } + return signature.VerifyDataWithSource( - &StableMarshalerWrapper{part}, + wrapper, sigRdr, - signature.WithBuffer(buf), ) } diff --git a/util/signature/buffer.go b/util/signature/buffer.go new file mode 100644 index 00000000..2a7f1e0c --- /dev/null +++ b/util/signature/buffer.go @@ -0,0 +1,29 @@ +package signature + +import "sync" + +const poolSliceMaxSize = 64 * 1024 + +var buffersPool = sync.Pool{ + New: func() any { + return make([]byte, 0) + }, +} + +func newBufferFromPool(size int) []byte { + result := buffersPool.Get().([]byte) + if cap(result) < size { + result = make([]byte, size) + } else { + result = result[:size] + } + return result +} + +func returnBufferToPool(buf []byte) { + if cap(buf) > poolSliceMaxSize { + return + } + buf = buf[:0] + buffersPool.Put(buf) +} diff --git a/util/signature/data.go b/util/signature/data.go index 6de94668..d2bf1dcd 100644 --- a/util/signature/data.go +++ b/util/signature/data.go @@ -35,7 +35,10 @@ func SignDataWithHandler(key *ecdsa.PrivateKey, src DataSource, handler KeySigna opts[i](cfg) } - data, err := readSignedData(cfg, src) + buffer := newBufferFromPool(src.SignedDataSize()) + defer returnBufferToPool(buffer) + + data, err := src.ReadSignedData(buffer) if err != nil { return err } @@ -61,7 +64,10 @@ func VerifyDataWithSource(dataSrc DataSource, sigSrc KeySignatureSource, opts .. opts[i](cfg) } - data, err := readSignedData(cfg, dataSrc) + buffer := newBufferFromPool(dataSrc.SignedDataSize()) + defer returnBufferToPool(buffer) + + data, err := dataSrc.ReadSignedData(buffer) if err != nil { return err } @@ -76,13 +82,3 @@ func SignData(key *ecdsa.PrivateKey, v DataWithSignature, opts ...SignOption) er func VerifyData(src DataWithSignature, opts ...SignOption) error { return VerifyDataWithSource(src, src.GetSignature, opts...) } - -func readSignedData(cfg *cfg, src DataSource) ([]byte, error) { - size := src.SignedDataSize() - if cfg.buffer == nil || cap(cfg.buffer) < size { - cfg.buffer = make([]byte, size) - } else { - cfg.buffer = cfg.buffer[:size] - } - return src.ReadSignedData(cfg.buffer) -} diff --git a/util/signature/options.go b/util/signature/options.go index 28728dc4..a8a3522e 100644 --- a/util/signature/options.go +++ b/util/signature/options.go @@ -13,7 +13,6 @@ import ( type cfg struct { schemeFixed bool scheme refs.SignatureScheme - buffer []byte } func defaultCfg() *cfg { @@ -36,9 +35,10 @@ func verify(cfg *cfg, data []byte, sig *refs.Signature) error { case refs.ECDSA_RFC6979_SHA256: return crypto.VerifyRFC6979(pub, data, sig.GetSign()) case refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT: - buf := make([]byte, base64.StdEncoding.EncodedLen(len(data))) - base64.StdEncoding.Encode(buf, data) - if !walletconnect.Verify(pub, buf, sig.GetSign()) { + buffer := newBufferFromPool(base64.StdEncoding.EncodedLen(len(data))) + defer returnBufferToPool(buffer) + base64.StdEncoding.Encode(buffer, data) + if !walletconnect.Verify(pub, buffer, sig.GetSign()) { return crypto.ErrInvalidSignature } return nil @@ -54,9 +54,10 @@ func sign(cfg *cfg, key *ecdsa.PrivateKey, data []byte) ([]byte, error) { case refs.ECDSA_RFC6979_SHA256: return crypto.SignRFC6979(key, data) case refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT: - buf := make([]byte, base64.StdEncoding.EncodedLen(len(data))) - base64.StdEncoding.Encode(buf, data) - return walletconnect.Sign(key, buf) + buffer := newBufferFromPool(base64.StdEncoding.EncodedLen(len(data))) + defer returnBufferToPool(buffer) + base64.StdEncoding.Encode(buffer, data) + return walletconnect.Sign(key, buffer) default: panic(fmt.Sprintf("unsupported scheme %s", cfg.scheme)) } @@ -69,13 +70,6 @@ func SignWithRFC6979() SignOption { } } -// WithBuffer allows providing pre-allocated buffer for signature verification. -func WithBuffer(buf []byte) SignOption { - return func(c *cfg) { - c.buffer = buf - } -} - func SignWithWalletConnect() SignOption { return func(c *cfg) { c.scheme = refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT