diff --git a/pkg/services/accounting/sign.go b/pkg/services/accounting/sign.go index e98d9b3a..9efb063f 100644 --- a/pkg/services/accounting/sign.go +++ b/pkg/services/accounting/sign.go @@ -22,17 +22,6 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server { } func (s *signService) Balance(ctx context.Context, req *accounting.BalanceRequest) (*accounting.BalanceResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.Balance(ctx, req.(*accounting.BalanceRequest)) - }, - func() util.ResponseMessage { - return new(accounting.BalanceResponse) - }, - ) - if err != nil { - return nil, err - } - - return resp.(*accounting.BalanceResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.Balance(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } diff --git a/pkg/services/container/sign.go b/pkg/services/container/sign.go index 9e77e2e2..55125335 100644 --- a/pkg/services/container/sign.go +++ b/pkg/services/container/sign.go @@ -22,113 +22,64 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server { } func (s *signService) Put(ctx context.Context, req *container.PutRequest) (*container.PutResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.Put(ctx, req.(*container.PutRequest)) - }, - func() util.ResponseMessage { - return new(container.PutResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(container.PutResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*container.PutResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.Put(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *signService) Delete(ctx context.Context, req *container.DeleteRequest) (*container.DeleteResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.Delete(ctx, req.(*container.DeleteRequest)) - }, - func() util.ResponseMessage { - return new(container.DeleteResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(container.DeleteResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*container.DeleteResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *signService) Get(ctx context.Context, req *container.GetRequest) (*container.GetResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.Get(ctx, req.(*container.GetRequest)) - }, - func() util.ResponseMessage { - return new(container.GetResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(container.GetResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*container.GetResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.Get(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *signService) List(ctx context.Context, req *container.ListRequest) (*container.ListResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.List(ctx, req.(*container.ListRequest)) - }, - func() util.ResponseMessage { - return new(container.ListResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(container.ListResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*container.ListResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.List(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *signService) SetExtendedACL(ctx context.Context, req *container.SetExtendedACLRequest) (*container.SetExtendedACLResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.SetExtendedACL(ctx, req.(*container.SetExtendedACLRequest)) - }, - func() util.ResponseMessage { - return new(container.SetExtendedACLResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(container.SetExtendedACLResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*container.SetExtendedACLResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.SetExtendedACL(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *signService) GetExtendedACL(ctx context.Context, req *container.GetExtendedACLRequest) (*container.GetExtendedACLResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.GetExtendedACL(ctx, req.(*container.GetExtendedACLRequest)) - }, - func() util.ResponseMessage { - return new(container.GetExtendedACLResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(container.GetExtendedACLResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*container.GetExtendedACLResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.GetExtendedACL(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *signService) AnnounceUsedSpace(ctx context.Context, req *container.AnnounceUsedSpaceRequest) (*container.AnnounceUsedSpaceResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.AnnounceUsedSpace(ctx, req.(*container.AnnounceUsedSpaceRequest)) - }, - func() util.ResponseMessage { - return new(container.AnnounceUsedSpaceResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(container.AnnounceUsedSpaceResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*container.AnnounceUsedSpaceResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.AnnounceUsedSpace(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } diff --git a/pkg/services/netmap/sign.go b/pkg/services/netmap/sign.go index 85b19d86..e665519e 100644 --- a/pkg/services/netmap/sign.go +++ b/pkg/services/netmap/sign.go @@ -24,49 +24,28 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server { func (s *signService) LocalNodeInfo( ctx context.Context, req *netmap.LocalNodeInfoRequest) (*netmap.LocalNodeInfoResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.LocalNodeInfo(ctx, req.(*netmap.LocalNodeInfoRequest)) - }, - func() util.ResponseMessage { - return new(netmap.LocalNodeInfoResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(netmap.LocalNodeInfoResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*netmap.LocalNodeInfoResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.LocalNodeInfo(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *signService) NetworkInfo(ctx context.Context, req *netmap.NetworkInfoRequest) (*netmap.NetworkInfoResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.NetworkInfo(ctx, req.(*netmap.NetworkInfoRequest)) - }, - func() util.ResponseMessage { - return new(netmap.NetworkInfoResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(netmap.NetworkInfoResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*netmap.NetworkInfoResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.NetworkInfo(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *signService) Snapshot(ctx context.Context, req *netmap.SnapshotRequest) (*netmap.SnapshotResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.Snapshot(ctx, req.(*netmap.SnapshotRequest)) - }, - func() util.ResponseMessage { - return new(netmap.SnapshotResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(netmap.SnapshotResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*netmap.SnapshotResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.Snapshot(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } diff --git a/pkg/services/object/sign.go b/pkg/services/object/sign.go index 5b3578e2..f160b4c5 100644 --- a/pkg/services/object/sign.go +++ b/pkg/services/object/sign.go @@ -105,35 +105,21 @@ func (s *SignService) Put() (PutObjectStream, error) { } func (s *SignService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.Head(ctx, req.(*object.HeadRequest)) - }, - func() util.ResponseMessage { - return new(object.HeadResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(object.HeadResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*object.HeadResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.Head(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *SignService) PutSingle(ctx context.Context, req *object.PutSingleRequest) (*object.PutSingleResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.PutSingle(ctx, req.(*object.PutSingleRequest)) - }, - func() util.ResponseMessage { - return new(object.PutSingleResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(object.PutSingleResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*object.PutSingleResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.PutSingle(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *searchStreamSigner) Send(resp *object.SearchResponse) error { @@ -172,19 +158,12 @@ func (s *SignService) Search(req *object.SearchRequest, stream SearchStream) err } func (s *SignService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.Delete(ctx, req.(*object.DeleteRequest)) - }, - func() util.ResponseMessage { - return new(object.DeleteResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(object.DeleteResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*object.DeleteResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } func (s *getRangeStreamSigner) Send(resp *object.GetRangeResponse) error { @@ -209,17 +188,10 @@ func (s *SignService) GetRange(req *object.GetRangeRequest, stream GetObjectRang } func (s *SignService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.GetRangeHash(ctx, req.(*object.GetRangeHashRequest)) - }, - func() util.ResponseMessage { - return new(object.GetRangeHashResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(object.GetRangeHashResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*object.GetRangeHashResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.GetRangeHash(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } diff --git a/pkg/services/session/sign.go b/pkg/services/session/sign.go index 1156dc53..33c0a531 100644 --- a/pkg/services/session/sign.go +++ b/pkg/services/session/sign.go @@ -22,17 +22,10 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server { } func (s *signService) Create(ctx context.Context, req *session.CreateRequest) (*session.CreateResponse, error) { - resp, err := s.sigSvc.HandleUnaryRequest(ctx, req, - func(ctx context.Context, req any) (util.ResponseMessage, error) { - return s.svc.Create(ctx, req.(*session.CreateRequest)) - }, - func() util.ResponseMessage { - return new(session.CreateResponse) - }, - ) - if err != nil { - return nil, err + if err := s.sigSvc.VerifyRequest(req); err != nil { + resp := new(session.CreateResponse) + return resp, s.sigSvc.SignResponse(req, resp, err) } - - return resp.(*session.CreateResponse), nil + resp, err := util.EnsureNonNilResponse(s.svc.Create(ctx, req)) + return resp, s.sigSvc.SignResponse(req, resp, err) } diff --git a/pkg/services/util/sign.go b/pkg/services/util/sign.go index cb4be308..e4157566 100644 --- a/pkg/services/util/sign.go +++ b/pkg/services/util/sign.go @@ -172,44 +172,39 @@ func (s *SignService) HandleServerStreamRequest( return nil } -func (s *SignService) HandleUnaryRequest(ctx context.Context, req any, handler UnaryHandler, blankResp ResponseConstructor) (ResponseMessage, error) { +func (s *SignService) SignResponse(req RequestMessage, resp ResponseMessage, err error) error { // handle protocol versions <=2.10 (API statuses was introduced in 2.11 only) // req argument should be strengthen with type RequestMessage - statusSupported := isStatusSupported(req.(RequestMessage)) // panic is OK here for now - - var ( - resp ResponseMessage - err error - ) - - // verify request signatures - if err = signature.VerifyServiceMessage(req); err != nil { - var sigErr apistatus.SignatureVerification - sigErr.SetMessage(err.Error()) - - err = sigErr - } else { - // process request - resp, err = handler(ctx, req) - } + statusSupported := isStatusSupported(req) if err != nil { if !statusSupported { - return nil, err + return err } - resp = blankResp() - setStatusV2(resp, err) } // sign the response - if err = signResponse(s.key, resp, statusSupported); err != nil { - return nil, err - } + return signResponse(s.key, resp, statusSupported) +} - return resp, nil +func (s *SignService) VerifyRequest(req RequestMessage) error { + if err := signature.VerifyServiceMessage(req); err != nil { + var sigErr apistatus.SignatureVerification + sigErr.SetMessage(err.Error()) + return sigErr + } + return nil +} + +// EnsureNonNilResponse creates an appropriate response struct if it is nil. +func EnsureNonNilResponse[T any](resp *T, err error) (*T, error) { + if resp != nil { + return resp, err + } + return new(T), err } func isStatusSupported(req RequestMessage) bool {