[#533] services: Assume API supports status codes

Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
This commit is contained in:
Evgenii Stratonikov 2023-07-26 15:47:32 +03:00 committed by Evgenii Stratonikov
parent ec8b4fdc48
commit 7b0fdf0202
6 changed files with 47 additions and 70 deletions

View file

@ -23,5 +23,5 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
func (s *signService) Balance(ctx context.Context, req *accounting.BalanceRequest) (*accounting.BalanceResponse, error) { func (s *signService) Balance(ctx context.Context, req *accounting.BalanceRequest) (*accounting.BalanceResponse, error) {
resp, err := util.EnsureNonNilResponse(s.svc.Balance(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.Balance(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }

View file

@ -24,62 +24,62 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
func (s *signService) Put(ctx context.Context, req *container.PutRequest) (*container.PutResponse, error) { func (s *signService) Put(ctx context.Context, req *container.PutRequest) (*container.PutResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.PutResponse) resp := new(container.PutResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.Put(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.Put(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *signService) Delete(ctx context.Context, req *container.DeleteRequest) (*container.DeleteResponse, error) { func (s *signService) Delete(ctx context.Context, req *container.DeleteRequest) (*container.DeleteResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.DeleteResponse) resp := new(container.DeleteResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *signService) Get(ctx context.Context, req *container.GetRequest) (*container.GetResponse, error) { func (s *signService) Get(ctx context.Context, req *container.GetRequest) (*container.GetResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.GetResponse) resp := new(container.GetResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.Get(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.Get(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *signService) List(ctx context.Context, req *container.ListRequest) (*container.ListResponse, error) { func (s *signService) List(ctx context.Context, req *container.ListRequest) (*container.ListResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.ListResponse) resp := new(container.ListResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.List(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.List(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *signService) SetExtendedACL(ctx context.Context, req *container.SetExtendedACLRequest) (*container.SetExtendedACLResponse, error) { func (s *signService) SetExtendedACL(ctx context.Context, req *container.SetExtendedACLRequest) (*container.SetExtendedACLResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.SetExtendedACLResponse) resp := new(container.SetExtendedACLResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.SetExtendedACL(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.SetExtendedACL(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *signService) GetExtendedACL(ctx context.Context, req *container.GetExtendedACLRequest) (*container.GetExtendedACLResponse, error) { func (s *signService) GetExtendedACL(ctx context.Context, req *container.GetExtendedACLRequest) (*container.GetExtendedACLResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.GetExtendedACLResponse) resp := new(container.GetExtendedACLResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.GetExtendedACL(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.GetExtendedACL(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *signService) AnnounceUsedSpace(ctx context.Context, req *container.AnnounceUsedSpaceRequest) (*container.AnnounceUsedSpaceResponse, error) { func (s *signService) AnnounceUsedSpace(ctx context.Context, req *container.AnnounceUsedSpaceRequest) (*container.AnnounceUsedSpaceResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.AnnounceUsedSpaceResponse) resp := new(container.AnnounceUsedSpaceResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.AnnounceUsedSpace(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.AnnounceUsedSpace(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }

View file

@ -26,26 +26,26 @@ func (s *signService) LocalNodeInfo(
req *netmap.LocalNodeInfoRequest) (*netmap.LocalNodeInfoResponse, error) { req *netmap.LocalNodeInfoRequest) (*netmap.LocalNodeInfoResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(netmap.LocalNodeInfoResponse) resp := new(netmap.LocalNodeInfoResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.LocalNodeInfo(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.LocalNodeInfo(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *signService) NetworkInfo(ctx context.Context, req *netmap.NetworkInfoRequest) (*netmap.NetworkInfoResponse, error) { func (s *signService) NetworkInfo(ctx context.Context, req *netmap.NetworkInfoRequest) (*netmap.NetworkInfoResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(netmap.NetworkInfoResponse) resp := new(netmap.NetworkInfoResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.NetworkInfo(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.NetworkInfo(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *signService) Snapshot(ctx context.Context, req *netmap.SnapshotRequest) (*netmap.SnapshotResponse, error) { func (s *signService) Snapshot(ctx context.Context, req *netmap.SnapshotRequest) (*netmap.SnapshotResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(netmap.SnapshotResponse) resp := new(netmap.SnapshotResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.Snapshot(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.Snapshot(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }

View file

@ -19,29 +19,25 @@ type SignService struct {
type searchStreamSigner struct { type searchStreamSigner struct {
SearchStream SearchStream
statusSupported bool sigSvc *util.SignService
sigSvc *util.SignService
nonEmptyResp bool // set on first Send call nonEmptyResp bool // set on first Send call
} }
type getStreamSigner struct { type getStreamSigner struct {
GetObjectStream GetObjectStream
statusSupported bool sigSvc *util.SignService
sigSvc *util.SignService
} }
type putStreamSigner struct { type putStreamSigner struct {
sigSvc *util.SignService sigSvc *util.SignService
stream PutObjectStream stream PutObjectStream
statusSupported bool err error
err error
} }
type getRangeStreamSigner struct { type getRangeStreamSigner struct {
GetObjectRangeStream GetObjectRangeStream
statusSupported bool sigSvc *util.SignService
sigSvc *util.SignService
} }
func NewSignService(key *ecdsa.PrivateKey, svc ServiceServer) *SignService { func NewSignService(key *ecdsa.PrivateKey, svc ServiceServer) *SignService {
@ -53,7 +49,7 @@ func NewSignService(key *ecdsa.PrivateKey, svc ServiceServer) *SignService {
} }
func (s *getStreamSigner) Send(resp *object.GetResponse) error { func (s *getStreamSigner) Send(resp *object.GetResponse) error {
if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil { if err := s.sigSvc.SignResponse(resp, nil); err != nil {
return err return err
} }
return s.GetObjectStream.Send(resp) return s.GetObjectStream.Send(resp)
@ -62,20 +58,17 @@ func (s *getStreamSigner) Send(resp *object.GetResponse) error {
func (s *SignService) Get(req *object.GetRequest, stream GetObjectStream) error { func (s *SignService) Get(req *object.GetRequest, stream GetObjectStream) error {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.GetResponse) resp := new(object.GetResponse)
_ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) _ = s.sigSvc.SignResponse(resp, err)
return stream.Send(resp) return stream.Send(resp)
} }
return s.svc.Get(req, &getStreamSigner{ return s.svc.Get(req, &getStreamSigner{
GetObjectStream: stream, GetObjectStream: stream,
sigSvc: s.sigSvc, sigSvc: s.sigSvc,
statusSupported: util.IsStatusSupported(req),
}) })
} }
func (s *putStreamSigner) Send(ctx context.Context, req *object.PutRequest) error { func (s *putStreamSigner) Send(ctx context.Context, req *object.PutRequest) error {
s.statusSupported = util.IsStatusSupported(req)
if s.err = s.sigSvc.VerifyRequest(req); s.err != nil { if s.err = s.sigSvc.VerifyRequest(req); s.err != nil {
return util.ErrAbortStream return util.ErrAbortStream
} }
@ -96,7 +89,7 @@ func (s *putStreamSigner) CloseAndRecv(ctx context.Context) (resp *object.PutRes
} }
} }
return resp, s.sigSvc.SignResponse(s.statusSupported, resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *SignService) Put() (PutObjectStream, error) { func (s *SignService) Put() (PutObjectStream, error) {
@ -114,24 +107,24 @@ func (s *SignService) Put() (PutObjectStream, error) {
func (s *SignService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) { func (s *SignService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.HeadResponse) resp := new(object.HeadResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.Head(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.Head(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *SignService) PutSingle(ctx context.Context, req *object.PutSingleRequest) (*object.PutSingleResponse, error) { func (s *SignService) PutSingle(ctx context.Context, req *object.PutSingleRequest) (*object.PutSingleResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.PutSingleResponse) resp := new(object.PutSingleResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.PutSingle(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.PutSingle(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *searchStreamSigner) Send(resp *object.SearchResponse) error { func (s *searchStreamSigner) Send(resp *object.SearchResponse) error {
s.nonEmptyResp = true s.nonEmptyResp = true
if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil { if err := s.sigSvc.SignResponse(resp, nil); err != nil {
return err return err
} }
return s.SearchStream.Send(resp) return s.SearchStream.Send(resp)
@ -140,14 +133,13 @@ func (s *searchStreamSigner) Send(resp *object.SearchResponse) error {
func (s *SignService) Search(req *object.SearchRequest, stream SearchStream) error { func (s *SignService) Search(req *object.SearchRequest, stream SearchStream) error {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.SearchResponse) resp := new(object.SearchResponse)
_ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) _ = s.sigSvc.SignResponse(resp, err)
return stream.Send(resp) return stream.Send(resp)
} }
ss := &searchStreamSigner{ ss := &searchStreamSigner{
SearchStream: stream, SearchStream: stream,
sigSvc: s.sigSvc, sigSvc: s.sigSvc,
statusSupported: util.IsStatusSupported(req),
} }
err := s.svc.Search(req, ss) err := s.svc.Search(req, ss)
if err == nil && !ss.nonEmptyResp { if err == nil && !ss.nonEmptyResp {
@ -164,14 +156,14 @@ func (s *SignService) Search(req *object.SearchRequest, stream SearchStream) err
func (s *SignService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) { func (s *SignService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.DeleteResponse) resp := new(object.DeleteResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.Delete(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
func (s *getRangeStreamSigner) Send(resp *object.GetRangeResponse) error { func (s *getRangeStreamSigner) Send(resp *object.GetRangeResponse) error {
if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil { if err := s.sigSvc.SignResponse(resp, nil); err != nil {
return err return err
} }
return s.GetObjectRangeStream.Send(resp) return s.GetObjectRangeStream.Send(resp)
@ -180,22 +172,21 @@ func (s *getRangeStreamSigner) Send(resp *object.GetRangeResponse) error {
func (s *SignService) GetRange(req *object.GetRangeRequest, stream GetObjectRangeStream) error { func (s *SignService) GetRange(req *object.GetRangeRequest, stream GetObjectRangeStream) error {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.GetRangeResponse) resp := new(object.GetRangeResponse)
_ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) _ = s.sigSvc.SignResponse(resp, err)
return stream.Send(resp) return stream.Send(resp)
} }
return s.svc.GetRange(req, &getRangeStreamSigner{ return s.svc.GetRange(req, &getRangeStreamSigner{
GetObjectRangeStream: stream, GetObjectRangeStream: stream,
sigSvc: s.sigSvc, sigSvc: s.sigSvc,
statusSupported: util.IsStatusSupported(req),
}) })
} }
func (s *SignService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) { func (s *SignService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.GetRangeHashResponse) resp := new(object.GetRangeHashResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.GetRangeHash(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.GetRangeHash(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }

View file

@ -24,8 +24,8 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
func (s *signService) Create(ctx context.Context, req *session.CreateRequest) (*session.CreateResponse, error) { func (s *signService) Create(ctx context.Context, req *session.CreateRequest) (*session.CreateResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(session.CreateResponse) resp := new(session.CreateResponse)
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }
resp, err := util.EnsureNonNilResponse(s.svc.Create(ctx, req)) resp, err := util.EnsureNonNilResponse(s.svc.Create(ctx, req))
return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err) return resp, s.sigSvc.SignResponse(resp, err)
} }

View file

@ -36,12 +36,8 @@ func NewUnarySignService(key *ecdsa.PrivateKey) *SignService {
// The signature error affects the result depending on the protocol version: // The signature error affects the result depending on the protocol version:
// - if status return is supported, panics since we cannot return the failed status, because it will not be signed. // - if status return is supported, panics since we cannot return the failed status, because it will not be signed.
// - otherwise, returns error in order to transport it directly. // - otherwise, returns error in order to transport it directly.
func (s *SignService) SignResponse(statusSupported bool, resp ResponseMessage, err error) error { func (s *SignService) SignResponse(resp ResponseMessage, err error) error {
if err != nil { if err != nil {
if !statusSupported {
return err
}
setStatusV2(resp, err) setStatusV2(resp, err)
} }
@ -70,16 +66,6 @@ func EnsureNonNilResponse[T any](resp *T, err error) (*T, error) {
return new(T), err return new(T), err
} }
// IsStatusSupported returns true iff request version implies expecting status return.
// This allows us to handle protocol versions <=2.10 (API statuses was introduced in 2.11 only).
func IsStatusSupported(req RequestMessage) bool {
version := req.GetMetaHeader().GetVersion()
mjr := version.GetMajor()
return mjr > 2 || mjr == 2 && version.GetMinor() >= 11
}
func setStatusV2(resp ResponseMessage, err error) { func setStatusV2(resp ResponseMessage, err error) {
// unwrap error // unwrap error
for e := errors.Unwrap(err); e != nil; e = errors.Unwrap(err) { for e := errors.Unwrap(err); e != nil; e = errors.Unwrap(err) {