[#3] signature: Add buffer pool

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
support/v2.15
Dmitrii Stepanov 2023-03-09 11:33:21 +03:00
parent 73fde0e37c
commit ec0d0274fa
5 changed files with 67 additions and 58 deletions

View File

@ -101,10 +101,13 @@ func signMessageParts(key *ecdsa.PrivateKey, body, meta, header stableMarshaler,
func signServiceMessagePart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*refs.Signature)) error { func signServiceMessagePart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*refs.Signature)) error {
var sig *refs.Signature var sig *refs.Signature
wrapper := StableMarshalerWrapper{
SM: part,
}
// sign part // sign part
if err := signature.SignDataWithHandler( if err := signature.SignDataWithHandler(
key, key,
&StableMarshalerWrapper{part}, wrapper,
func(s *refs.Signature) { func(s *refs.Signature) {
sig = s sig = s
}, },

View File

@ -16,12 +16,6 @@ type signatureProvider interface {
GetOriginSignature() *refs.Signature GetOriginSignature() *refs.Signature
} }
type buffers struct {
Body []byte
Meta []byte
Header []byte
}
// VerifyServiceMessage verifies service message. // VerifyServiceMessage verifies service message.
func VerifyServiceMessage(msg interface{}) error { func VerifyServiceMessage(msg interface{}) error {
switch v := msg.(type) { switch v := msg.(type) {
@ -40,23 +34,14 @@ func verifyServiceRequest(v serviceRequest) error {
meta := v.GetMetaHeader() meta := v.GetMetaHeader()
verificationHeader := v.GetVerificationHeader() verificationHeader := v.GetVerificationHeader()
body := serviceMessageBody(v) body := serviceMessageBody(v)
buffers := createBuffers(body.StableSize(), meta.StableSize(), verificationHeader.StableSize()) return verifyServiceRequestRecursive(body, meta, verificationHeader)
return verifyServiceRequestRecursive(body, meta, verificationHeader, buffers)
} }
func createBuffers(bodySize, metaSize, headerSize int) *buffers { func verifyServiceRequestRecursive(body stableMarshaler, meta *session.RequestMetaHeader, verify *session.RequestVerificationHeader) error {
return &buffers{
Body: make([]byte, 0, bodySize),
Meta: make([]byte, 0, metaSize),
Header: make([]byte, 0, headerSize),
}
}
func verifyServiceRequestRecursive(body stableMarshaler, meta *session.RequestMetaHeader, verify *session.RequestVerificationHeader, buffers *buffers) error {
verificationHeaderOrigin := verify.GetOrigin() verificationHeaderOrigin := verify.GetOrigin()
metaOrigin := meta.GetOrigin() metaOrigin := meta.GetOrigin()
stop, err := verifyMessageParts(body, meta, verificationHeaderOrigin, verificationHeaderOrigin != nil, verify, buffers) stop, err := verifyMessageParts(body, meta, verificationHeaderOrigin, verificationHeaderOrigin != nil, verify)
if err != nil { if err != nil {
return err return err
} }
@ -64,21 +49,21 @@ func verifyServiceRequestRecursive(body stableMarshaler, meta *session.RequestMe
return nil return nil
} }
return verifyServiceRequestRecursive(body, metaOrigin, verificationHeaderOrigin, buffers) return verifyServiceRequestRecursive(body, metaOrigin, verificationHeaderOrigin)
} }
func verifyMessageParts(body, meta, originHeader stableMarshaler, hasOriginHeader bool, sigProvider signatureProvider, buffers *buffers) (stop bool, err error) { func verifyMessageParts(body, meta, originHeader stableMarshaler, hasOriginHeader bool, sigProvider signatureProvider) (stop bool, err error) {
eg := &errgroup.Group{} eg := &errgroup.Group{}
eg.Go(func() error { eg.Go(func() error {
if err := verifyServiceMessagePart(meta, sigProvider.GetMetaSignature, buffers.Meta); err != nil { if err := verifyServiceMessagePart(meta, sigProvider.GetMetaSignature); err != nil {
return fmt.Errorf("could not verify meta header: %w", err) return fmt.Errorf("could not verify meta header: %w", err)
} }
return nil return nil
}) })
eg.Go(func() error { eg.Go(func() error {
if err := verifyServiceMessagePart(originHeader, sigProvider.GetOriginSignature, buffers.Header); err != nil { if err := verifyServiceMessagePart(originHeader, sigProvider.GetOriginSignature); err != nil {
return fmt.Errorf("could not verify origin of verification header: %w", err) return fmt.Errorf("could not verify origin of verification header: %w", err)
} }
return nil return nil
@ -86,7 +71,7 @@ func verifyMessageParts(body, meta, originHeader stableMarshaler, hasOriginHeade
if !hasOriginHeader { if !hasOriginHeader {
eg.Go(func() error { eg.Go(func() error {
if err := verifyServiceMessagePart(body, sigProvider.GetBodySignature, buffers.Body); err != nil { if err := verifyServiceMessagePart(body, sigProvider.GetBodySignature); err != nil {
return fmt.Errorf("could not verify body: %w", err) return fmt.Errorf("could not verify body: %w", err)
} }
return nil return nil
@ -112,15 +97,14 @@ func verifyServiceResponse(v serviceResponse) error {
meta := v.GetMetaHeader() meta := v.GetMetaHeader()
verificationHeader := v.GetVerificationHeader() verificationHeader := v.GetVerificationHeader()
body := serviceMessageBody(v) body := serviceMessageBody(v)
buffers := createBuffers(body.StableSize(), meta.StableSize(), verificationHeader.StableSize()) return verifyServiceResponseRecursive(body, meta, verificationHeader)
return verifyServiceResponseRecursive(body, meta, verificationHeader, buffers)
} }
func verifyServiceResponseRecursive(body stableMarshaler, meta *session.ResponseMetaHeader, verify *session.ResponseVerificationHeader, buffers *buffers) error { func verifyServiceResponseRecursive(body stableMarshaler, meta *session.ResponseMetaHeader, verify *session.ResponseVerificationHeader) error {
verificationHeaderOrigin := verify.GetOrigin() verificationHeaderOrigin := verify.GetOrigin()
metaOrigin := meta.GetOrigin() metaOrigin := meta.GetOrigin()
stop, err := verifyMessageParts(body, meta, verificationHeaderOrigin, verificationHeaderOrigin != nil, verify, buffers) stop, err := verifyMessageParts(body, meta, verificationHeaderOrigin, verificationHeaderOrigin != nil, verify)
if err != nil { if err != nil {
return err return err
} }
@ -128,13 +112,16 @@ func verifyServiceResponseRecursive(body stableMarshaler, meta *session.Response
return nil return nil
} }
return verifyServiceResponseRecursive(body, metaOrigin, verificationHeaderOrigin, buffers) return verifyServiceResponseRecursive(body, metaOrigin, verificationHeaderOrigin)
} }
func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *refs.Signature, buf []byte) error { func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *refs.Signature) error {
wrapper := StableMarshalerWrapper{
SM: part,
}
return signature.VerifyDataWithSource( return signature.VerifyDataWithSource(
&StableMarshalerWrapper{part}, wrapper,
sigRdr, sigRdr,
signature.WithBuffer(buf),
) )
} }

View File

@ -0,0 +1,29 @@
package signature
import "sync"
const poolSliceMaxSize = 64 * 1024
var buffersPool = sync.Pool{
New: func() any {
return make([]byte, 0)
},
}
func newBufferFromPool(size int) []byte {
result := buffersPool.Get().([]byte)
if cap(result) < size {
result = make([]byte, size)
} else {
result = result[:size]
}
return result
}
func returnBufferToPool(buf []byte) {
if cap(buf) > poolSliceMaxSize {
return
}
buf = buf[:0]
buffersPool.Put(buf)
}

View File

@ -35,7 +35,10 @@ func SignDataWithHandler(key *ecdsa.PrivateKey, src DataSource, handler KeySigna
opts[i](cfg) opts[i](cfg)
} }
data, err := readSignedData(cfg, src) buffer := newBufferFromPool(src.SignedDataSize())
defer returnBufferToPool(buffer)
data, err := src.ReadSignedData(buffer)
if err != nil { if err != nil {
return err return err
} }
@ -61,7 +64,10 @@ func VerifyDataWithSource(dataSrc DataSource, sigSrc KeySignatureSource, opts ..
opts[i](cfg) opts[i](cfg)
} }
data, err := readSignedData(cfg, dataSrc) buffer := newBufferFromPool(dataSrc.SignedDataSize())
defer returnBufferToPool(buffer)
data, err := dataSrc.ReadSignedData(buffer)
if err != nil { if err != nil {
return err return err
} }
@ -76,13 +82,3 @@ func SignData(key *ecdsa.PrivateKey, v DataWithSignature, opts ...SignOption) er
func VerifyData(src DataWithSignature, opts ...SignOption) error { func VerifyData(src DataWithSignature, opts ...SignOption) error {
return VerifyDataWithSource(src, src.GetSignature, opts...) return VerifyDataWithSource(src, src.GetSignature, opts...)
} }
func readSignedData(cfg *cfg, src DataSource) ([]byte, error) {
size := src.SignedDataSize()
if cfg.buffer == nil || cap(cfg.buffer) < size {
cfg.buffer = make([]byte, size)
} else {
cfg.buffer = cfg.buffer[:size]
}
return src.ReadSignedData(cfg.buffer)
}

View File

@ -13,7 +13,6 @@ import (
type cfg struct { type cfg struct {
schemeFixed bool schemeFixed bool
scheme refs.SignatureScheme scheme refs.SignatureScheme
buffer []byte
} }
func defaultCfg() *cfg { func defaultCfg() *cfg {
@ -36,9 +35,10 @@ func verify(cfg *cfg, data []byte, sig *refs.Signature) error {
case refs.ECDSA_RFC6979_SHA256: case refs.ECDSA_RFC6979_SHA256:
return crypto.VerifyRFC6979(pub, data, sig.GetSign()) return crypto.VerifyRFC6979(pub, data, sig.GetSign())
case refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT: case refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT:
buf := make([]byte, base64.StdEncoding.EncodedLen(len(data))) buffer := newBufferFromPool(base64.StdEncoding.EncodedLen(len(data)))
base64.StdEncoding.Encode(buf, data) defer returnBufferToPool(buffer)
if !walletconnect.Verify(pub, buf, sig.GetSign()) { base64.StdEncoding.Encode(buffer, data)
if !walletconnect.Verify(pub, buffer, sig.GetSign()) {
return crypto.ErrInvalidSignature return crypto.ErrInvalidSignature
} }
return nil return nil
@ -54,9 +54,10 @@ func sign(cfg *cfg, key *ecdsa.PrivateKey, data []byte) ([]byte, error) {
case refs.ECDSA_RFC6979_SHA256: case refs.ECDSA_RFC6979_SHA256:
return crypto.SignRFC6979(key, data) return crypto.SignRFC6979(key, data)
case refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT: case refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT:
buf := make([]byte, base64.StdEncoding.EncodedLen(len(data))) buffer := newBufferFromPool(base64.StdEncoding.EncodedLen(len(data)))
base64.StdEncoding.Encode(buf, data) defer returnBufferToPool(buffer)
return walletconnect.Sign(key, buf) base64.StdEncoding.Encode(buffer, data)
return walletconnect.Sign(key, buffer)
default: default:
panic(fmt.Sprintf("unsupported scheme %s", cfg.scheme)) panic(fmt.Sprintf("unsupported scheme %s", cfg.scheme))
} }
@ -69,13 +70,6 @@ func SignWithRFC6979() SignOption {
} }
} }
// WithBuffer allows providing pre-allocated buffer for signature verification.
func WithBuffer(buf []byte) SignOption {
return func(c *cfg) {
c.buffer = buf
}
}
func SignWithWalletConnect() SignOption { func SignWithWalletConnect() SignOption {
return func(c *cfg) { return func(c *cfg) {
c.scheme = refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT c.scheme = refs.ECDSA_RFC6979_SHA256_WALLET_CONNECT