diff --git a/pkg/services/object/sign.go b/pkg/services/object/sign.go index 32119f0f9..95decd048 100644 --- a/pkg/services/object/sign.go +++ b/pkg/services/object/sign.go @@ -5,7 +5,6 @@ import ( "crypto/ecdsa" "github.com/nspcc-dev/neofs-api-go/v2/object" - "github.com/nspcc-dev/neofs-api-go/v2/signature" "github.com/nspcc-dev/neofs-node/pkg/services/util" "github.com/pkg/errors" ) @@ -19,21 +18,19 @@ type signService struct { } type searchStreamSigner struct { - stream *util.MessageStreamer + stream *util.ResponseMessageStreamer } type getStreamSigner struct { - stream *util.MessageStreamer + stream *util.ResponseMessageStreamer } type putStreamSigner struct { - key *ecdsa.PrivateKey - - stream object.PutObjectStreamer + stream *util.RequestMessageStreamer } type getRangeStreamSigner struct { - stream *util.MessageStreamer + stream *util.ResponseMessageStreamer } func NewSignService(key *ecdsa.PrivateKey, svc object.Service) object.Service { @@ -55,7 +52,7 @@ func (s *getStreamSigner) Recv() (*object.GetResponse, error) { func (s *signService) Get(ctx context.Context, req *object.GetRequest) (object.GetObjectStreamer, error) { stream, err := s.sigSvc.HandleServerStreamRequest(ctx, req, - func(ctx context.Context, req interface{}) (util.MessageReader, error) { + func(ctx context.Context, req interface{}) (util.ResponseMessageReader, error) { stream, err := s.svc.Get(ctx, req.(*object.GetRequest)) if err != nil { return nil, err @@ -76,10 +73,6 @@ func (s *signService) Get(ctx context.Context, req *object.GetRequest) (object.G } func (s *putStreamSigner) Send(req *object.PutRequest) error { - if err := signature.VerifyServiceMessage(req); err != nil { - return errors.Wrap(err, "could not verify request") - } - return s.stream.Send(req) } @@ -89,11 +82,7 @@ func (s *putStreamSigner) CloseAndRecv() (*object.PutResponse, error) { return nil, errors.Wrap(err, "could not receive response") } - if err := signature.SignServiceMessage(s.key, r); err != nil { - return nil, errors.Wrap(err, "could not sign response") - } - - return r, nil + return r.(*object.PutResponse), nil } func (s *signService) Put(ctx context.Context) (object.PutObjectStreamer, error) { @@ -103,8 +92,14 @@ func (s *signService) Put(ctx context.Context) (object.PutObjectStreamer, error) } return &putStreamSigner{ - key: s.key, - stream: stream, + stream: s.sigSvc.CreateRequestStreamer( + func(req interface{}) error { + return stream.Send(req.(*object.PutRequest)) + }, + func() (interface{}, error) { + return stream.CloseAndRecv() + }, + ), }, nil } @@ -132,7 +127,7 @@ func (s *searchStreamSigner) Recv() (*object.SearchResponse, error) { func (s *signService) Search(ctx context.Context, req *object.SearchRequest) (object.SearchObjectStreamer, error) { stream, err := s.sigSvc.HandleServerStreamRequest(ctx, req, - func(ctx context.Context, req interface{}) (util.MessageReader, error) { + func(ctx context.Context, req interface{}) (util.ResponseMessageReader, error) { stream, err := s.svc.Search(ctx, req.(*object.SearchRequest)) if err != nil { return nil, err @@ -176,7 +171,7 @@ func (s *getRangeStreamSigner) Recv() (*object.GetRangeResponse, error) { func (s *signService) GetRange(ctx context.Context, req *object.GetRangeRequest) (object.GetRangeObjectStreamer, error) { stream, err := s.sigSvc.HandleServerStreamRequest(ctx, req, - func(ctx context.Context, req interface{}) (util.MessageReader, error) { + func(ctx context.Context, req interface{}) (util.ResponseMessageReader, error) { stream, err := s.svc.GetRange(ctx, req.(*object.GetRangeRequest)) if err != nil { return nil, err diff --git a/pkg/services/util/sign.go b/pkg/services/util/sign.go index 03ef4088b..0ff416dc0 100644 --- a/pkg/services/util/sign.go +++ b/pkg/services/util/sign.go @@ -14,14 +14,26 @@ type SignService struct { key *ecdsa.PrivateKey } -type ServerStreamHandler func(context.Context, interface{}) (MessageReader, error) +type ServerStreamHandler func(context.Context, interface{}) (ResponseMessageReader, error) -type MessageReader func() (interface{}, error) +type ResponseMessageReader func() (interface{}, error) -type MessageStreamer struct { +type ResponseMessageStreamer struct { key *ecdsa.PrivateKey - recv MessageReader + recv ResponseMessageReader +} + +type RequestMessageWriter func(interface{}) error + +type ClientStreamCloser func() (interface{}, error) + +type RequestMessageStreamer struct { + key *ecdsa.PrivateKey + + send RequestMessageWriter + + close ClientStreamCloser } func NewUnarySignService(key *ecdsa.PrivateKey) *SignService { @@ -30,7 +42,37 @@ func NewUnarySignService(key *ecdsa.PrivateKey) *SignService { } } -func (s *MessageStreamer) Recv() (interface{}, error) { +func (s *RequestMessageStreamer) Send(req interface{}) error { + // verify request signatures + if err := signature.VerifyServiceMessage(req); err != nil { + return errors.Wrap(err, "could not verify request") + } + + return s.send(req) +} + +func (s *RequestMessageStreamer) CloseAndRecv() (interface{}, error) { + resp, err := s.close() + if err != nil { + return nil, errors.Wrap(err, "could not close stream and receive response") + } + + if err := signature.SignServiceMessage(s.key, resp); err != nil { + return nil, errors.Wrap(err, "could not sign response") + } + + return resp, nil +} + +func (s *SignService) CreateRequestStreamer(sender RequestMessageWriter, closer ClientStreamCloser) *RequestMessageStreamer { + return &RequestMessageStreamer{ + key: s.key, + send: sender, + close: closer, + } +} + +func (s *ResponseMessageStreamer) Recv() (interface{}, error) { m, err := s.recv() if err != nil { return nil, errors.Wrap(err, "could not receive response message for signing") @@ -43,7 +85,7 @@ func (s *MessageStreamer) Recv() (interface{}, error) { return m, nil } -func (s *SignService) HandleServerStreamRequest(ctx context.Context, req interface{}, handler ServerStreamHandler) (*MessageStreamer, error) { +func (s *SignService) HandleServerStreamRequest(ctx context.Context, req interface{}, handler ServerStreamHandler) (*ResponseMessageStreamer, error) { // verify request signatures if err := signature.VerifyServiceMessage(req); err != nil { return nil, errors.Wrap(err, "could not verify request") @@ -54,7 +96,7 @@ func (s *SignService) HandleServerStreamRequest(ctx context.Context, req interfa return nil, errors.Wrap(err, "could not create message reader") } - return &MessageStreamer{ + return &ResponseMessageStreamer{ key: s.key, recv: msgRdr, }, nil