diff --git a/pkg/services/accounting/sign.go b/pkg/services/accounting/sign.go index 9efb063f..be7b08a3 100644 --- a/pkg/services/accounting/sign.go +++ b/pkg/services/accounting/sign.go @@ -23,5 +23,5 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server { func (s *signService) Balance(ctx context.Context, req *accounting.BalanceRequest) (*accounting.BalanceResponse, error) { resp, err := util.EnsureNonNilResponse(s.svc.Balance(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } diff --git a/pkg/services/container/sign.go b/pkg/services/container/sign.go index 55125335..b336f19c 100644 --- a/pkg/services/container/sign.go +++ b/pkg/services/container/sign.go @@ -24,62 +24,62 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server { func (s *signService) Put(ctx context.Context, req *container.PutRequest) (*container.PutResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(container.PutResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.Put(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *signService) Delete(ctx context.Context, req *container.DeleteRequest) (*container.DeleteResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(container.DeleteResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *signService) Get(ctx context.Context, req *container.GetRequest) (*container.GetResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(container.GetResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.Get(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *signService) List(ctx context.Context, req *container.ListRequest) (*container.ListResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(container.ListResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.List(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *signService) SetExtendedACL(ctx context.Context, req *container.SetExtendedACLRequest) (*container.SetExtendedACLResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(container.SetExtendedACLResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.SetExtendedACL(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *signService) GetExtendedACL(ctx context.Context, req *container.GetExtendedACLRequest) (*container.GetExtendedACLResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(container.GetExtendedACLResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.GetExtendedACL(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *signService) AnnounceUsedSpace(ctx context.Context, req *container.AnnounceUsedSpaceRequest) (*container.AnnounceUsedSpaceResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(container.AnnounceUsedSpaceResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.AnnounceUsedSpace(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } diff --git a/pkg/services/netmap/sign.go b/pkg/services/netmap/sign.go index e665519e..2d01164a 100644 --- a/pkg/services/netmap/sign.go +++ b/pkg/services/netmap/sign.go @@ -26,26 +26,26 @@ func (s *signService) LocalNodeInfo( req *netmap.LocalNodeInfoRequest) (*netmap.LocalNodeInfoResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(netmap.LocalNodeInfoResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.LocalNodeInfo(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *signService) NetworkInfo(ctx context.Context, req *netmap.NetworkInfoRequest) (*netmap.NetworkInfoResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(netmap.NetworkInfoResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.NetworkInfo(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *signService) Snapshot(ctx context.Context, req *netmap.SnapshotRequest) (*netmap.SnapshotResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(netmap.SnapshotResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.Snapshot(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } diff --git a/pkg/services/object/sign.go b/pkg/services/object/sign.go index f160b4c5..c516872e 100644 --- a/pkg/services/object/sign.go +++ b/pkg/services/object/sign.go @@ -18,27 +18,30 @@ type SignService struct { } type searchStreamSigner struct { - util.ServerStream - - respWriter util.ResponseMessageWriter + SearchStream + statusSupported bool + sigSvc *util.SignService nonEmptyResp bool // set on first Send call } type getStreamSigner struct { - util.ServerStream - - respWriter util.ResponseMessageWriter + GetObjectStream + statusSupported bool + sigSvc *util.SignService } type putStreamSigner struct { - stream *util.RequestMessageStreamer + sigSvc *util.SignService + stream PutObjectStream + statusSupported bool + err error } type getRangeStreamSigner struct { - util.ServerStream - - respWriter util.ResponseMessageWriter + GetObjectRangeStream + statusSupported bool + sigSvc *util.SignService } func NewSignService(key *ecdsa.PrivateKey, svc ServiceServer) *SignService { @@ -50,37 +53,50 @@ func NewSignService(key *ecdsa.PrivateKey, svc ServiceServer) *SignService { } func (s *getStreamSigner) Send(resp *object.GetResponse) error { - return s.respWriter(resp) + if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil { + return err + } + return s.GetObjectStream.Send(resp) } func (s *SignService) Get(req *object.GetRequest, stream GetObjectStream) error { - return s.sigSvc.HandleServerStreamRequest(req, - func(resp util.ResponseMessage) error { - return stream.Send(resp.(*object.GetResponse)) - }, - func() util.ResponseMessage { - return new(object.GetResponse) - }, - func(respWriter util.ResponseMessageWriter) error { - return s.svc.Get(req, &getStreamSigner{ - ServerStream: stream, - respWriter: respWriter, - }) - }, - ) + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(object.GetResponse) + _ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) + return stream.Send(resp) + } + + return s.svc.Get(req, &getStreamSigner{ + GetObjectStream: stream, + sigSvc: s.sigSvc, + statusSupported: util.IsStatusSupported(req), + }) } func (s *putStreamSigner) Send(ctx context.Context, req *object.PutRequest) error { - return s.stream.Send(ctx, req) + s.statusSupported = util.IsStatusSupported(req) + + if s.err = s.sigSvc.VerifyRequest(req); s.err != nil { + return util.ErrAbortStream + } + if s.err = s.stream.Send(ctx, req); s.err != nil { + return util.ErrAbortStream + } + return nil } -func (s *putStreamSigner) CloseAndRecv(ctx context.Context) (*object.PutResponse, error) { - r, err := s.stream.CloseAndRecv(ctx) - if err != nil { - return nil, fmt.Errorf("could not receive response: %w", err) +func (s *putStreamSigner) CloseAndRecv(ctx context.Context) (resp *object.PutResponse, err error) { + if s.err != nil { + err = s.err + resp = new(object.PutResponse) + } else { + resp, err = s.stream.CloseAndRecv(ctx) + if err != nil { + return nil, fmt.Errorf("could not close stream and receive response: %w", err) + } } - return r.(*object.PutResponse), nil + return resp, s.sigSvc.SignResponse(s.statusSupported, resp, err) } func (s *SignService) Put() (PutObjectStream, error) { @@ -90,108 +106,96 @@ func (s *SignService) Put() (PutObjectStream, error) { } return &putStreamSigner{ - stream: s.sigSvc.CreateRequestStreamer( - func(ctx context.Context, req any) error { - return stream.Send(ctx, req.(*object.PutRequest)) - }, - func(ctx context.Context) (util.ResponseMessage, error) { - return stream.CloseAndRecv(ctx) - }, - func() util.ResponseMessage { - return new(object.PutResponse) - }, - ), + stream: stream, + sigSvc: s.sigSvc, }, nil } func (s *SignService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(object.HeadResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.Head(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *SignService) PutSingle(ctx context.Context, req *object.PutSingleRequest) (*object.PutSingleResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(object.PutSingleResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.PutSingle(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *searchStreamSigner) Send(resp *object.SearchResponse) error { s.nonEmptyResp = true - return s.respWriter(resp) + if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil { + return err + } + return s.SearchStream.Send(resp) } func (s *SignService) Search(req *object.SearchRequest, stream SearchStream) error { - return s.sigSvc.HandleServerStreamRequest(req, - func(resp util.ResponseMessage) error { - return stream.Send(resp.(*object.SearchResponse)) - }, - func() util.ResponseMessage { - return new(object.SearchResponse) - }, - func(respWriter util.ResponseMessageWriter) error { - stream := &searchStreamSigner{ - ServerStream: stream, - respWriter: respWriter, - } + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(object.SearchResponse) + _ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) + return stream.Send(resp) + } - err := s.svc.Search(req, stream) - - if err == nil && !stream.nonEmptyResp { - // The higher component does not write any response in the case of an empty result (which is correct). - // With the introduction of status returns at least one answer must be signed and sent to the client. - // This approach is supported by clients who do not know how to work with statuses (one could make - // a switch according to the protocol version from the request, but the costs of sending an empty - // answer can be neglected due to the gradual refusal to use the "old" clients). - return stream.Send(new(object.SearchResponse)) - } - - return err - }, - ) + ss := &searchStreamSigner{ + SearchStream: stream, + sigSvc: s.sigSvc, + statusSupported: util.IsStatusSupported(req), + } + err := s.svc.Search(req, ss) + if err == nil && !ss.nonEmptyResp { + // The higher component does not write any response in the case of an empty result (which is correct). + // With the introduction of status returns at least one answer must be signed and sent to the client. + // This approach is supported by clients who do not know how to work with statuses (one could make + // a switch according to the protocol version from the request, but the costs of sending an empty + // answer can be neglected due to the gradual refusal to use the "old" clients). + return stream.Send(new(object.SearchResponse)) + } + return err } func (s *SignService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(object.DeleteResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } func (s *getRangeStreamSigner) Send(resp *object.GetRangeResponse) error { - return s.respWriter(resp) + if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil { + return err + } + return s.GetObjectRangeStream.Send(resp) } func (s *SignService) GetRange(req *object.GetRangeRequest, stream GetObjectRangeStream) error { - return s.sigSvc.HandleServerStreamRequest(req, - func(resp util.ResponseMessage) error { - return stream.Send(resp.(*object.GetRangeResponse)) - }, - func() util.ResponseMessage { - return new(object.GetRangeResponse) - }, - func(respWriter util.ResponseMessageWriter) error { - return s.svc.GetRange(req, &getRangeStreamSigner{ - ServerStream: stream, - respWriter: respWriter, - }) - }, - ) + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(object.GetRangeResponse) + _ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) + return stream.Send(resp) + } + + return s.svc.GetRange(req, &getRangeStreamSigner{ + GetObjectRangeStream: stream, + sigSvc: s.sigSvc, + statusSupported: util.IsStatusSupported(req), + }) } func (s *SignService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(object.GetRangeHashResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.GetRangeHash(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } diff --git a/pkg/services/session/sign.go b/pkg/services/session/sign.go index 33c0a531..ffce0621 100644 --- a/pkg/services/session/sign.go +++ b/pkg/services/session/sign.go @@ -24,8 +24,8 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server { func (s *signService) Create(ctx context.Context, req *session.CreateRequest) (*session.CreateResponse, error) { if err := s.sigSvc.VerifyRequest(req); err != nil { resp := new(session.CreateResponse) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } resp, err := util.EnsureNonNilResponse(s.svc.Create(ctx, req)) - return resp, s.sigSvc.SignResponse(req, resp, err) + return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) } diff --git a/pkg/services/util/sign.go b/pkg/services/util/sign.go index e4157566..a3e0c946 100644 --- a/pkg/services/util/sign.go +++ b/pkg/services/util/sign.go @@ -1,7 +1,6 @@ package util import ( - "context" "crypto/ecdsa" "errors" "fmt" @@ -21,163 +20,23 @@ type ResponseMessage interface { SetMetaHeader(*session.ResponseMetaHeader) } -type UnaryHandler func(context.Context, any) (ResponseMessage, error) - type SignService struct { key *ecdsa.PrivateKey } -type ResponseMessageWriter func(ResponseMessage) error - -type ServerStreamHandler func(context.Context, any) (ResponseMessageReader, error) - -type ResponseMessageReader func() (ResponseMessage, error) - var ErrAbortStream = errors.New("abort message stream") -type ResponseConstructor func() ResponseMessage - -type RequestMessageWriter func(context.Context, any) error - -type ClientStreamCloser func(context.Context) (ResponseMessage, error) - -type RequestMessageStreamer struct { - key *ecdsa.PrivateKey - - send RequestMessageWriter - - close ClientStreamCloser - - respCons ResponseConstructor - - statusSupported bool - - sendErr error -} - func NewUnarySignService(key *ecdsa.PrivateKey) *SignService { return &SignService{ key: key, } } -func (s *RequestMessageStreamer) Send(ctx context.Context, req any) error { - // req argument should be strengthen with type RequestMessage - s.statusSupported = isStatusSupported(req.(RequestMessage)) // panic is OK here for now - - var err error - - // verify request signatures - if err = signature.VerifyServiceMessage(req); err != nil { - err = fmt.Errorf("could not verify request: %w", err) - } else { - err = s.send(ctx, req) - } - - if err != nil { - if !s.statusSupported { - return err - } - - s.sendErr = err - - return ErrAbortStream - } - - return nil -} - -func (s *RequestMessageStreamer) CloseAndRecv(ctx context.Context) (ResponseMessage, error) { - var ( - resp ResponseMessage - err error - ) - - if s.sendErr != nil { - err = s.sendErr - } else { - resp, err = s.close(ctx) - if err != nil { - err = fmt.Errorf("could not close stream and receive response: %w", err) - } - } - - if err != nil { - if !s.statusSupported { - return nil, err - } - - resp = s.respCons() - - setStatusV2(resp, err) - } - - if err = signResponse(s.key, resp, s.statusSupported); err != nil { - return nil, err - } - - return resp, nil -} - -func (s *SignService) CreateRequestStreamer(sender RequestMessageWriter, closer ClientStreamCloser, blankResp ResponseConstructor) *RequestMessageStreamer { - return &RequestMessageStreamer{ - key: s.key, - send: sender, - close: closer, - - respCons: blankResp, - } -} - -func (s *SignService) HandleServerStreamRequest( - req any, - respWriter ResponseMessageWriter, - blankResp ResponseConstructor, - respWriterCaller func(ResponseMessageWriter) error, -) error { - // handle protocol versions <=2.10 (API statuses was introduced in 2.11 only) - - // req argument should be strengthen with type RequestMessage - statusSupported := isStatusSupported(req.(RequestMessage)) // panic is OK here for now - - var err error - - // verify request signatures - if err = signature.VerifyServiceMessage(req); err != nil { - err = fmt.Errorf("could not verify request: %w", err) - } else { - err = respWriterCaller(func(resp ResponseMessage) error { - if err := signResponse(s.key, resp, statusSupported); err != nil { - return err - } - - return respWriter(resp) - }) - } - - if err != nil { - if !statusSupported { - return err - } - - resp := blankResp() - - setStatusV2(resp, err) - - _ = signResponse(s.key, resp, false) // panics or returns nil with false arg - - return respWriter(resp) - } - - return nil -} - -func (s *SignService) SignResponse(req RequestMessage, resp ResponseMessage, err error) error { - // handle protocol versions <=2.10 (API statuses was introduced in 2.11 only) - - // req argument should be strengthen with type RequestMessage - statusSupported := isStatusSupported(req) - +// SignResponse response with private key via signature.SignServiceMessage. +// The signature error affects the result depending on the protocol version: +// - if status return is supported, panics since we cannot return the failed status, because it will not be signed. +// - otherwise, returns error in order to transport it directly. +func (s *SignService) SignResponse(statusSupported bool, resp ResponseMessage, err error) error { if err != nil { if !statusSupported { return err @@ -186,8 +45,18 @@ func (s *SignService) SignResponse(req RequestMessage, resp ResponseMessage, err setStatusV2(resp, err) } - // sign the response - return signResponse(s.key, resp, statusSupported) + err = signature.SignServiceMessage(s.key, resp) + if err != nil { + err = fmt.Errorf("could not sign response: %w", err) + + if statusSupported { + // We can't pass this error as status code since response will be unsigned. + // Isn't expected in practice, so panic is ok here. + panic(err) + } + } + + return err } func (s *SignService) VerifyRequest(req RequestMessage) error { @@ -207,7 +76,9 @@ func EnsureNonNilResponse[T any](resp *T, err error) (*T, error) { return new(T), err } -func isStatusSupported(req RequestMessage) bool { +// IsStatusSupported returns true iff request version implies expecting status return. +// This allows us to handle protocol versions <=2.10 (API statuses was introduced in 2.11 only). +func IsStatusSupported(req RequestMessage) bool { version := req.GetMetaHeader().GetVersion() mjr := version.GetMajor() @@ -223,22 +94,3 @@ func setStatusV2(resp ResponseMessage, err error) { session.SetStatus(resp, apistatus.ToStatusV2(apistatus.ErrToStatus(err))) } - -// signs response with private key via signature.SignServiceMessage. -// The signature error affects the result depending on the protocol version: -// - if status return is supported, panics since we cannot return the failed status, because it will not be signed; -// - otherwise, returns error in order to transport it directly. -func signResponse(key *ecdsa.PrivateKey, resp any, statusSupported bool) error { - err := signature.SignServiceMessage(key, resp) - if err != nil { - err = fmt.Errorf("could not sign response: %w", err) - - if statusSupported { - // We can't pass this error as status code since response will be unsigned. - // Isn't expected in practice, so panic is ok here. - panic(err) - } - } - - return err -}