From 372160d04807b12ffddacaa8ed949b93b29f98be Mon Sep 17 00:00:00 2001
From: Evgenii Stratonikov <e.stratonikov@yadro.com>
Date: Fri, 30 Dec 2022 18:26:43 +0300
Subject: [PATCH] [#6] services/util: Remove `SignService.HandleUnaryRequest`

There is no need in a wrapper with many from-`interface{}` conversions.

Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
---
 pkg/services/accounting/sign.go |  15 +---
 pkg/services/container/sign.go  | 119 ++++++++++----------------------
 pkg/services/netmap/sign.go     |  51 ++++----------
 pkg/services/object/sign.go     |  68 ++++++------------
 pkg/services/session/sign.go    |  17 ++---
 pkg/services/util/sign.go       |  45 ++++++------
 6 files changed, 97 insertions(+), 218 deletions(-)

diff --git a/pkg/services/accounting/sign.go b/pkg/services/accounting/sign.go
index e98d9b3af..9efb063f5 100644
--- a/pkg/services/accounting/sign.go
+++ b/pkg/services/accounting/sign.go
@@ -22,17 +22,6 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
 }
 
 func (s *signService) Balance(ctx context.Context, req *accounting.BalanceRequest) (*accounting.BalanceResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.Balance(ctx, req.(*accounting.BalanceRequest))
-		},
-		func() util.ResponseMessage {
-			return new(accounting.BalanceResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
-	}
-
-	return resp.(*accounting.BalanceResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.Balance(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
diff --git a/pkg/services/container/sign.go b/pkg/services/container/sign.go
index 9e77e2e21..55125335f 100644
--- a/pkg/services/container/sign.go
+++ b/pkg/services/container/sign.go
@@ -22,113 +22,64 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
 }
 
 func (s *signService) Put(ctx context.Context, req *container.PutRequest) (*container.PutResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.Put(ctx, req.(*container.PutRequest))
-		},
-		func() util.ResponseMessage {
-			return new(container.PutResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(container.PutResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*container.PutResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.Put(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *signService) Delete(ctx context.Context, req *container.DeleteRequest) (*container.DeleteResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.Delete(ctx, req.(*container.DeleteRequest))
-		},
-		func() util.ResponseMessage {
-			return new(container.DeleteResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(container.DeleteResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*container.DeleteResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *signService) Get(ctx context.Context, req *container.GetRequest) (*container.GetResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.Get(ctx, req.(*container.GetRequest))
-		},
-		func() util.ResponseMessage {
-			return new(container.GetResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(container.GetResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*container.GetResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.Get(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *signService) List(ctx context.Context, req *container.ListRequest) (*container.ListResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.List(ctx, req.(*container.ListRequest))
-		},
-		func() util.ResponseMessage {
-			return new(container.ListResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(container.ListResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*container.ListResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.List(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *signService) SetExtendedACL(ctx context.Context, req *container.SetExtendedACLRequest) (*container.SetExtendedACLResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.SetExtendedACL(ctx, req.(*container.SetExtendedACLRequest))
-		},
-		func() util.ResponseMessage {
-			return new(container.SetExtendedACLResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(container.SetExtendedACLResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*container.SetExtendedACLResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.SetExtendedACL(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *signService) GetExtendedACL(ctx context.Context, req *container.GetExtendedACLRequest) (*container.GetExtendedACLResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.GetExtendedACL(ctx, req.(*container.GetExtendedACLRequest))
-		},
-		func() util.ResponseMessage {
-			return new(container.GetExtendedACLResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(container.GetExtendedACLResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*container.GetExtendedACLResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.GetExtendedACL(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *signService) AnnounceUsedSpace(ctx context.Context, req *container.AnnounceUsedSpaceRequest) (*container.AnnounceUsedSpaceResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.AnnounceUsedSpace(ctx, req.(*container.AnnounceUsedSpaceRequest))
-		},
-		func() util.ResponseMessage {
-			return new(container.AnnounceUsedSpaceResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(container.AnnounceUsedSpaceResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*container.AnnounceUsedSpaceResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.AnnounceUsedSpace(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
diff --git a/pkg/services/netmap/sign.go b/pkg/services/netmap/sign.go
index 85b19d862..e665519e1 100644
--- a/pkg/services/netmap/sign.go
+++ b/pkg/services/netmap/sign.go
@@ -24,49 +24,28 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
 func (s *signService) LocalNodeInfo(
 	ctx context.Context,
 	req *netmap.LocalNodeInfoRequest) (*netmap.LocalNodeInfoResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.LocalNodeInfo(ctx, req.(*netmap.LocalNodeInfoRequest))
-		},
-		func() util.ResponseMessage {
-			return new(netmap.LocalNodeInfoResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(netmap.LocalNodeInfoResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*netmap.LocalNodeInfoResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.LocalNodeInfo(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *signService) NetworkInfo(ctx context.Context, req *netmap.NetworkInfoRequest) (*netmap.NetworkInfoResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.NetworkInfo(ctx, req.(*netmap.NetworkInfoRequest))
-		},
-		func() util.ResponseMessage {
-			return new(netmap.NetworkInfoResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(netmap.NetworkInfoResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*netmap.NetworkInfoResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.NetworkInfo(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *signService) Snapshot(ctx context.Context, req *netmap.SnapshotRequest) (*netmap.SnapshotResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.Snapshot(ctx, req.(*netmap.SnapshotRequest))
-		},
-		func() util.ResponseMessage {
-			return new(netmap.SnapshotResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(netmap.SnapshotResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*netmap.SnapshotResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.Snapshot(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
diff --git a/pkg/services/object/sign.go b/pkg/services/object/sign.go
index 5b3578e29..f160b4c56 100644
--- a/pkg/services/object/sign.go
+++ b/pkg/services/object/sign.go
@@ -105,35 +105,21 @@ func (s *SignService) Put() (PutObjectStream, error) {
 }
 
 func (s *SignService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.Head(ctx, req.(*object.HeadRequest))
-		},
-		func() util.ResponseMessage {
-			return new(object.HeadResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(object.HeadResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*object.HeadResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.Head(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *SignService) PutSingle(ctx context.Context, req *object.PutSingleRequest) (*object.PutSingleResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.PutSingle(ctx, req.(*object.PutSingleRequest))
-		},
-		func() util.ResponseMessage {
-			return new(object.PutSingleResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(object.PutSingleResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*object.PutSingleResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.PutSingle(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *searchStreamSigner) Send(resp *object.SearchResponse) error {
@@ -172,19 +158,12 @@ func (s *SignService) Search(req *object.SearchRequest, stream SearchStream) err
 }
 
 func (s *SignService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.Delete(ctx, req.(*object.DeleteRequest))
-		},
-		func() util.ResponseMessage {
-			return new(object.DeleteResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(object.DeleteResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*object.DeleteResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
 
 func (s *getRangeStreamSigner) Send(resp *object.GetRangeResponse) error {
@@ -209,17 +188,10 @@ func (s *SignService) GetRange(req *object.GetRangeRequest, stream GetObjectRang
 }
 
 func (s *SignService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.GetRangeHash(ctx, req.(*object.GetRangeHashRequest))
-		},
-		func() util.ResponseMessage {
-			return new(object.GetRangeHashResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(object.GetRangeHashResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*object.GetRangeHashResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.GetRangeHash(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
diff --git a/pkg/services/session/sign.go b/pkg/services/session/sign.go
index 1156dc538..33c0a531c 100644
--- a/pkg/services/session/sign.go
+++ b/pkg/services/session/sign.go
@@ -22,17 +22,10 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
 }
 
 func (s *signService) Create(ctx context.Context, req *session.CreateRequest) (*session.CreateResponse, error) {
-	resp, err := s.sigSvc.HandleUnaryRequest(ctx, req,
-		func(ctx context.Context, req any) (util.ResponseMessage, error) {
-			return s.svc.Create(ctx, req.(*session.CreateRequest))
-		},
-		func() util.ResponseMessage {
-			return new(session.CreateResponse)
-		},
-	)
-	if err != nil {
-		return nil, err
+	if err := s.sigSvc.VerifyRequest(req); err != nil {
+		resp := new(session.CreateResponse)
+		return resp, s.sigSvc.SignResponse(req, resp, err)
 	}
-
-	return resp.(*session.CreateResponse), nil
+	resp, err := util.EnsureNonNilResponse(s.svc.Create(ctx, req))
+	return resp, s.sigSvc.SignResponse(req, resp, err)
 }
diff --git a/pkg/services/util/sign.go b/pkg/services/util/sign.go
index cb4be3084..e41575663 100644
--- a/pkg/services/util/sign.go
+++ b/pkg/services/util/sign.go
@@ -172,44 +172,39 @@ func (s *SignService) HandleServerStreamRequest(
 	return nil
 }
 
-func (s *SignService) HandleUnaryRequest(ctx context.Context, req any, handler UnaryHandler, blankResp ResponseConstructor) (ResponseMessage, error) {
+func (s *SignService) SignResponse(req RequestMessage, resp ResponseMessage, err error) error {
 	// handle protocol versions <=2.10 (API statuses was introduced in 2.11 only)
 
 	// req argument should be strengthen with type RequestMessage
-	statusSupported := isStatusSupported(req.(RequestMessage)) // panic is OK here for now
-
-	var (
-		resp ResponseMessage
-		err  error
-	)
-
-	// verify request signatures
-	if err = signature.VerifyServiceMessage(req); err != nil {
-		var sigErr apistatus.SignatureVerification
-		sigErr.SetMessage(err.Error())
-
-		err = sigErr
-	} else {
-		// process request
-		resp, err = handler(ctx, req)
-	}
+	statusSupported := isStatusSupported(req)
 
 	if err != nil {
 		if !statusSupported {
-			return nil, err
+			return err
 		}
 
-		resp = blankResp()
-
 		setStatusV2(resp, err)
 	}
 
 	// sign the response
-	if err = signResponse(s.key, resp, statusSupported); err != nil {
-		return nil, err
-	}
+	return signResponse(s.key, resp, statusSupported)
+}
 
-	return resp, nil
+func (s *SignService) VerifyRequest(req RequestMessage) error {
+	if err := signature.VerifyServiceMessage(req); err != nil {
+		var sigErr apistatus.SignatureVerification
+		sigErr.SetMessage(err.Error())
+		return sigErr
+	}
+	return nil
+}
+
+// EnsureNonNilResponse creates an appropriate response struct if it is nil.
+func EnsureNonNilResponse[T any](resp *T, err error) (*T, error) {
+	if resp != nil {
+		return resp, err
+	}
+	return new(T), err
 }
 
 func isStatusSupported(req RequestMessage) bool {