diff --git a/pkg/services/object/sign.go b/pkg/services/object/sign.go index 9b44e9bc..477cd231 100644 --- a/pkg/services/object/sign.go +++ b/pkg/services/object/sign.go @@ -19,15 +19,11 @@ type signService struct { } type searchStreamSigner struct { - key *ecdsa.PrivateKey - - stream object.SearchObjectStreamer + stream *util.MessageStreamer } type getStreamSigner struct { - key *ecdsa.PrivateKey - - stream object.GetObjectStreamer + stream *util.MessageStreamer } type putStreamSigner struct { @@ -37,9 +33,7 @@ type putStreamSigner struct { } type getRangeStreamSigner struct { - key *ecdsa.PrivateKey - - stream object.GetRangeObjectStreamer + stream *util.MessageStreamer } func NewSignService(key *ecdsa.PrivateKey, svc object.Service) object.Service { @@ -56,17 +50,20 @@ func (s *getStreamSigner) Recv() (*object.GetResponse, error) { return nil, errors.Wrap(err, "could not receive response") } - if err := signature.SignServiceMessage(s.key, r); err != nil { - return nil, errors.Wrap(err, "could not sign response") - } - - return r, nil + return r.(*object.GetResponse), nil } func (s *signService) Get(ctx context.Context, req *object.GetRequest) (object.GetObjectStreamer, error) { - resp, err := s.unarySigService.HandleServerStreamRequest(ctx, req, - func(ctx context.Context, req interface{}) (interface{}, error) { - return s.svc.Get(ctx, req.(*object.GetRequest)) + stream, err := s.unarySigService.HandleServerStreamRequest(ctx, req, + func(ctx context.Context, req interface{}) (util.MessageReader, error) { + stream, err := s.svc.Get(ctx, req.(*object.GetRequest)) + if err != nil { + return nil, err + } + + return func() (interface{}, error) { + return stream.Recv() + }, nil }, ) if err != nil { @@ -74,8 +71,7 @@ func (s *signService) Get(ctx context.Context, req *object.GetRequest) (object.G } return &getStreamSigner{ - key: s.key, - stream: resp.(object.GetObjectStreamer), + stream: stream, }, nil } @@ -131,17 +127,20 @@ func (s *searchStreamSigner) Recv() (*object.SearchResponse, error) { return nil, errors.Wrap(err, "could not receive response") } - if err := signature.SignServiceMessage(s.key, r); err != nil { - return nil, errors.Wrap(err, "could not sign response") - } - - return r, nil + return r.(*object.SearchResponse), nil } func (s *signService) Search(ctx context.Context, req *object.SearchRequest) (object.SearchObjectStreamer, error) { - resp, err := s.unarySigService.HandleServerStreamRequest(ctx, req, - func(ctx context.Context, req interface{}) (interface{}, error) { - return s.svc.Search(ctx, req.(*object.SearchRequest)) + stream, err := s.unarySigService.HandleServerStreamRequest(ctx, req, + func(ctx context.Context, req interface{}) (util.MessageReader, error) { + stream, err := s.svc.Search(ctx, req.(*object.SearchRequest)) + if err != nil { + return nil, err + } + + return func() (interface{}, error) { + return stream.Recv() + }, nil }, ) if err != nil { @@ -149,8 +148,7 @@ func (s *signService) Search(ctx context.Context, req *object.SearchRequest) (ob } return &searchStreamSigner{ - key: s.key, - stream: resp.(object.SearchObjectStreamer), + stream: stream, }, nil } @@ -173,17 +171,20 @@ func (s *getRangeStreamSigner) Recv() (*object.GetRangeResponse, error) { return nil, errors.Wrap(err, "could not receive response") } - if err := signature.SignServiceMessage(s.key, r); err != nil { - return nil, errors.Wrap(err, "could not sign response") - } - - return r, nil + return r.(*object.GetRangeResponse), nil } func (s *signService) GetRange(ctx context.Context, req *object.GetRangeRequest) (object.GetRangeObjectStreamer, error) { - resp, err := s.unarySigService.HandleServerStreamRequest(ctx, req, - func(ctx context.Context, req interface{}) (interface{}, error) { - return s.svc.GetRange(ctx, req.(*object.GetRangeRequest)) + stream, err := s.unarySigService.HandleServerStreamRequest(ctx, req, + func(ctx context.Context, req interface{}) (util.MessageReader, error) { + stream, err := s.svc.GetRange(ctx, req.(*object.GetRangeRequest)) + if err != nil { + return nil, err + } + + return func() (interface{}, error) { + return stream.Recv() + }, nil }, ) if err != nil { @@ -191,8 +192,7 @@ func (s *signService) GetRange(ctx context.Context, req *object.GetRangeRequest) } return &getRangeStreamSigner{ - key: s.key, - stream: resp.(object.GetRangeObjectStreamer), + stream: stream, }, nil } diff --git a/pkg/services/util/sign.go b/pkg/services/util/sign.go index 57841f1a..21002025 100644 --- a/pkg/services/util/sign.go +++ b/pkg/services/util/sign.go @@ -14,32 +14,53 @@ type UnarySignService struct { key *ecdsa.PrivateKey } +type ServerStreamHandler func(context.Context, interface{}) (MessageReader, error) + +type MessageReader func() (interface{}, error) + +type MessageStreamer struct { + key *ecdsa.PrivateKey + + recv MessageReader +} + func NewUnarySignService(key *ecdsa.PrivateKey) *UnarySignService { return &UnarySignService{ key: key, } } -func (s *UnarySignService) HandleServerStreamRequest(ctx context.Context, req interface{}, handler UnaryHandler) (interface{}, error) { - return s.verifyAndProc(ctx, req, handler) +func (s *MessageStreamer) Recv() (interface{}, error) { + m, err := s.recv() + if err != nil { + return nil, errors.Wrap(err, "could not receive response message for signing") + } + + if err := signature.SignServiceMessage(s.key, m); err != nil { + return nil, errors.Wrap(err, "could not sign response message") + } + + return m, nil +} + +func (s *UnarySignService) HandleServerStreamRequest(ctx context.Context, req interface{}, handler ServerStreamHandler) (*MessageStreamer, error) { + // verify request signatures + if err := signature.VerifyServiceMessage(req); err != nil { + return nil, errors.Wrap(err, "could not verify request") + } + + msgRdr, err := handler(ctx, req) + if err != nil { + return nil, errors.Wrap(err, "could not create message reader") + } + + return &MessageStreamer{ + key: s.key, + recv: msgRdr, + }, nil } func (s *UnarySignService) HandleUnaryRequest(ctx context.Context, req interface{}, handler UnaryHandler) (interface{}, error) { - // verify and process request - resp, err := s.verifyAndProc(ctx, req, handler) - if err != nil { - return nil, err - } - - // sign the response - if err := signature.SignServiceMessage(s.key, resp); err != nil { - return nil, errors.Wrap(err, "could not sign response") - } - - return resp, nil -} - -func (s *UnarySignService) verifyAndProc(ctx context.Context, req interface{}, handler UnaryHandler) (interface{}, error) { // verify request signatures if err := signature.VerifyServiceMessage(req); err != nil { return nil, errors.Wrap(err, "could not verify request") @@ -51,5 +72,10 @@ func (s *UnarySignService) verifyAndProc(ctx context.Context, req interface{}, h return nil, errors.Wrap(err, "could not handle request") } + // sign the response + if err := signature.SignServiceMessage(s.key, resp); err != nil { + return nil, errors.Wrap(err, "could not sign response") + } + return resp, nil }