[#6] services/util: Remove remaining stream wrappers

Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
This commit is contained in:
Evgenii Stratonikov 2022-12-30 20:01:13 +03:00
parent dd23048ab3
commit 11b2cc867d
7 changed files with 142 additions and 287 deletions

View file

@ -23,5 +23,5 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
func (s *signService) Balance(ctx context.Context, req *accounting.BalanceRequest) (*accounting.BalanceResponse, error) { func (s *signService) Balance(ctx context.Context, req *accounting.BalanceRequest) (*accounting.BalanceResponse, error) {
resp, err := util.WrapResponse(s.svc.Balance(ctx, req)) resp, err := util.WrapResponse(s.svc.Balance(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }

View file

@ -24,62 +24,62 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
func (s *signService) Put(ctx context.Context, req *container.PutRequest) (*container.PutResponse, error) { func (s *signService) Put(ctx context.Context, req *container.PutRequest) (*container.PutResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.PutResponse) resp := new(container.PutResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.Put(ctx, req)) resp, err := util.WrapResponse(s.svc.Put(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) Delete(ctx context.Context, req *container.DeleteRequest) (*container.DeleteResponse, error) { func (s *signService) Delete(ctx context.Context, req *container.DeleteRequest) (*container.DeleteResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.DeleteResponse) resp := new(container.DeleteResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.Delete(ctx, req)) resp, err := util.WrapResponse(s.svc.Delete(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) Get(ctx context.Context, req *container.GetRequest) (*container.GetResponse, error) { func (s *signService) Get(ctx context.Context, req *container.GetRequest) (*container.GetResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.GetResponse) resp := new(container.GetResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.Get(ctx, req)) resp, err := util.WrapResponse(s.svc.Get(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) List(ctx context.Context, req *container.ListRequest) (*container.ListResponse, error) { func (s *signService) List(ctx context.Context, req *container.ListRequest) (*container.ListResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.ListResponse) resp := new(container.ListResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.List(ctx, req)) resp, err := util.WrapResponse(s.svc.List(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) SetExtendedACL(ctx context.Context, req *container.SetExtendedACLRequest) (*container.SetExtendedACLResponse, error) { func (s *signService) SetExtendedACL(ctx context.Context, req *container.SetExtendedACLRequest) (*container.SetExtendedACLResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.SetExtendedACLResponse) resp := new(container.SetExtendedACLResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.SetExtendedACL(ctx, req)) resp, err := util.WrapResponse(s.svc.SetExtendedACL(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) GetExtendedACL(ctx context.Context, req *container.GetExtendedACLRequest) (*container.GetExtendedACLResponse, error) { func (s *signService) GetExtendedACL(ctx context.Context, req *container.GetExtendedACLRequest) (*container.GetExtendedACLResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.GetExtendedACLResponse) resp := new(container.GetExtendedACLResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.GetExtendedACL(ctx, req)) resp, err := util.WrapResponse(s.svc.GetExtendedACL(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) AnnounceUsedSpace(ctx context.Context, req *container.AnnounceUsedSpaceRequest) (*container.AnnounceUsedSpaceResponse, error) { func (s *signService) AnnounceUsedSpace(ctx context.Context, req *container.AnnounceUsedSpaceRequest) (*container.AnnounceUsedSpaceResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(container.AnnounceUsedSpaceResponse) resp := new(container.AnnounceUsedSpaceResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.AnnounceUsedSpace(ctx, req)) resp, err := util.WrapResponse(s.svc.AnnounceUsedSpace(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }

View file

@ -26,26 +26,26 @@ func (s *signService) LocalNodeInfo(
req *netmap.LocalNodeInfoRequest) (*netmap.LocalNodeInfoResponse, error) { req *netmap.LocalNodeInfoRequest) (*netmap.LocalNodeInfoResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(netmap.LocalNodeInfoResponse) resp := new(netmap.LocalNodeInfoResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.LocalNodeInfo(ctx, req)) resp, err := util.WrapResponse(s.svc.LocalNodeInfo(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) NetworkInfo(ctx context.Context, req *netmap.NetworkInfoRequest) (*netmap.NetworkInfoResponse, error) { func (s *signService) NetworkInfo(ctx context.Context, req *netmap.NetworkInfoRequest) (*netmap.NetworkInfoResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(netmap.NetworkInfoResponse) resp := new(netmap.NetworkInfoResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.NetworkInfo(ctx, req)) resp, err := util.WrapResponse(s.svc.NetworkInfo(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) Snapshot(ctx context.Context, req *netmap.SnapshotRequest) (*netmap.SnapshotResponse, error) { func (s *signService) Snapshot(ctx context.Context, req *netmap.SnapshotRequest) (*netmap.SnapshotResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(netmap.SnapshotResponse) resp := new(netmap.SnapshotResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.Snapshot(ctx, req)) resp, err := util.WrapResponse(s.svc.Snapshot(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }

View file

@ -18,27 +18,29 @@ type SignService struct {
} }
type searchStreamSigner struct { type searchStreamSigner struct {
util.ServerStream SearchStream
statusSupported bool
respWriter util.ResponseMessageWriter sigSvc *util.SignService
nonEmptyResp bool // set on first Send call
nonEmptyResp bool // set on first Send call
} }
type getStreamSigner struct { type getStreamSigner struct {
util.ServerStream GetObjectStream
statusSupported bool
respWriter util.ResponseMessageWriter sigSvc *util.SignService
} }
type putStreamSigner struct { type putStreamSigner struct {
stream *util.RequestMessageStreamer sigSvc *util.SignService
stream PutObjectStream
statusSupported bool
err error
} }
type getRangeStreamSigner struct { type getRangeStreamSigner struct {
util.ServerStream GetObjectRangeStream
statusSupported bool
respWriter util.ResponseMessageWriter sigSvc *util.SignService
} }
func NewSignService(key *ecdsa.PrivateKey, svc ServiceServer) *SignService { func NewSignService(key *ecdsa.PrivateKey, svc ServiceServer) *SignService {
@ -50,37 +52,50 @@ func NewSignService(key *ecdsa.PrivateKey, svc ServiceServer) *SignService {
} }
func (s *getStreamSigner) Send(resp *object.GetResponse) error { func (s *getStreamSigner) Send(resp *object.GetResponse) error {
return s.respWriter(resp) if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil {
return err
}
return s.GetObjectStream.Send(resp)
} }
func (s *SignService) Get(req *object.GetRequest, stream GetObjectStream) error { func (s *SignService) Get(req *object.GetRequest, stream GetObjectStream) error {
return s.sigSvc.HandleServerStreamRequest(req, if err := s.sigSvc.VerifyRequest(req); err != nil {
func(resp util.ResponseMessage) error { resp := new(object.GetResponse)
return stream.Send(resp.(*object.GetResponse)) _ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
}, return stream.Send(resp)
func() util.ResponseMessage { }
return new(object.GetResponse)
}, return s.svc.Get(req, &getStreamSigner{
func(respWriter util.ResponseMessageWriter) error { GetObjectStream: stream,
return s.svc.Get(req, &getStreamSigner{ sigSvc: s.sigSvc,
ServerStream: stream, statusSupported: util.IsStatusSupported(req),
respWriter: respWriter, })
})
},
)
} }
func (s *putStreamSigner) Send(req *object.PutRequest) error { func (s *putStreamSigner) Send(req *object.PutRequest) error {
return s.stream.Send(req) s.statusSupported = util.IsStatusSupported(req)
if s.err = s.sigSvc.VerifyRequest(req); s.err != nil {
return util.ErrAbortStream
}
if s.err = s.stream.Send(req); s.err != nil {
return util.ErrAbortStream
}
return nil
} }
func (s *putStreamSigner) CloseAndRecv() (*object.PutResponse, error) { func (s *putStreamSigner) CloseAndRecv() (resp *object.PutResponse, err error) {
r, err := s.stream.CloseAndRecv() if s.err != nil {
if err != nil { err = s.err
return nil, fmt.Errorf("could not receive response: %w", err) resp = new(object.PutResponse)
} else {
resp, err = s.stream.CloseAndRecv()
if err != nil {
return nil, fmt.Errorf("could not close stream and receive response: %w", err)
}
} }
return r.(*object.PutResponse), nil return resp, s.sigSvc.SignResponse(s.statusSupported, resp, err)
} }
func (s *SignService) Put(ctx context.Context) (PutObjectStream, error) { func (s *SignService) Put(ctx context.Context) (PutObjectStream, error) {
@ -90,99 +105,87 @@ func (s *SignService) Put(ctx context.Context) (PutObjectStream, error) {
} }
return &putStreamSigner{ return &putStreamSigner{
stream: s.sigSvc.CreateRequestStreamer( stream: stream,
func(req any) error { sigSvc: s.sigSvc,
return stream.Send(req.(*object.PutRequest))
},
func() (util.ResponseMessage, error) {
return stream.CloseAndRecv()
},
func() util.ResponseMessage {
return new(object.PutResponse)
},
),
}, nil }, nil
} }
func (s *SignService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) { func (s *SignService) Head(ctx context.Context, req *object.HeadRequest) (*object.HeadResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.HeadResponse) resp := new(object.HeadResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.Head(ctx, req)) resp, err := util.WrapResponse(s.svc.Head(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *searchStreamSigner) Send(resp *object.SearchResponse) error { func (s *searchStreamSigner) Send(resp *object.SearchResponse) error {
s.nonEmptyResp = true s.nonEmptyResp = true
return s.respWriter(resp) if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil {
return err
}
return s.SearchStream.Send(resp)
} }
func (s *SignService) Search(req *object.SearchRequest, stream SearchStream) error { func (s *SignService) Search(req *object.SearchRequest, stream SearchStream) error {
return s.sigSvc.HandleServerStreamRequest(req, if err := s.sigSvc.VerifyRequest(req); err != nil {
func(resp util.ResponseMessage) error { resp := new(object.SearchResponse)
return stream.Send(resp.(*object.SearchResponse)) _ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
}, return stream.Send(resp)
func() util.ResponseMessage { }
return new(object.SearchResponse)
},
func(respWriter util.ResponseMessageWriter) error {
stream := &searchStreamSigner{
ServerStream: stream,
respWriter: respWriter,
}
err := s.svc.Search(req, stream) ss := &searchStreamSigner{
SearchStream: stream,
if err == nil && !stream.nonEmptyResp { sigSvc: s.sigSvc,
// The higher component does not write any response in the case of an empty result (which is correct). statusSupported: util.IsStatusSupported(req),
// With the introduction of status returns at least one answer must be signed and sent to the client. }
// This approach is supported by clients who do not know how to work with statuses (one could make err := s.svc.Search(req, ss)
// a switch according to the protocol version from the request, but the costs of sending an empty if err == nil && !ss.nonEmptyResp {
// answer can be neglected due to the gradual refusal to use the "old" clients). // The higher component does not write any response in the case of an empty result (which is correct).
return stream.Send(new(object.SearchResponse)) // With the introduction of status returns at least one answer must be signed and sent to the client.
} // This approach is supported by clients who do not know how to work with statuses (one could make
// a switch according to the protocol version from the request, but the costs of sending an empty
return err // answer can be neglected due to the gradual refusal to use the "old" clients).
}, return stream.Send(new(object.SearchResponse))
) }
return err
} }
func (s *SignService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) { func (s *SignService) Delete(ctx context.Context, req *object.DeleteRequest) (*object.DeleteResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.DeleteResponse) resp := new(object.DeleteResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.Delete(ctx, req)) resp, err := util.WrapResponse(s.svc.Delete(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *getRangeStreamSigner) Send(resp *object.GetRangeResponse) error { func (s *getRangeStreamSigner) Send(resp *object.GetRangeResponse) error {
return s.respWriter(resp) if err := s.sigSvc.SignResponse(s.statusSupported, resp, nil); err != nil {
return err
}
return s.GetObjectRangeStream.Send(resp)
} }
func (s *SignService) GetRange(req *object.GetRangeRequest, stream GetObjectRangeStream) error { func (s *SignService) GetRange(req *object.GetRangeRequest, stream GetObjectRangeStream) error {
return s.sigSvc.HandleServerStreamRequest(req, if err := s.sigSvc.VerifyRequest(req); err != nil {
func(resp util.ResponseMessage) error { resp := new(object.GetRangeResponse)
return stream.Send(resp.(*object.GetRangeResponse)) _ = s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
}, return stream.Send(resp)
func() util.ResponseMessage { }
return new(object.GetRangeResponse)
}, return s.svc.GetRange(req, &getRangeStreamSigner{
func(respWriter util.ResponseMessageWriter) error { GetObjectRangeStream: stream,
return s.svc.GetRange(req, &getRangeStreamSigner{ sigSvc: s.sigSvc,
ServerStream: stream, statusSupported: util.IsStatusSupported(req),
respWriter: respWriter, })
})
},
)
} }
func (s *SignService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) { func (s *SignService) GetRangeHash(ctx context.Context, req *object.GetRangeHashRequest) (*object.GetRangeHashResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(object.GetRangeHashResponse) resp := new(object.GetRangeHashResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.GetRangeHash(ctx, req)) resp, err := util.WrapResponse(s.svc.GetRangeHash(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }

View file

@ -24,17 +24,17 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
func (s *signService) AnnounceLocalTrust(ctx context.Context, req *reputation.AnnounceLocalTrustRequest) (*reputation.AnnounceLocalTrustResponse, error) { func (s *signService) AnnounceLocalTrust(ctx context.Context, req *reputation.AnnounceLocalTrustRequest) (*reputation.AnnounceLocalTrustResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(reputation.AnnounceLocalTrustResponse) resp := new(reputation.AnnounceLocalTrustResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.AnnounceLocalTrust(ctx, req)) resp, err := util.WrapResponse(s.svc.AnnounceLocalTrust(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
func (s *signService) AnnounceIntermediateResult(ctx context.Context, req *reputation.AnnounceIntermediateResultRequest) (*reputation.AnnounceIntermediateResultResponse, error) { func (s *signService) AnnounceIntermediateResult(ctx context.Context, req *reputation.AnnounceIntermediateResultRequest) (*reputation.AnnounceIntermediateResultResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(reputation.AnnounceIntermediateResultResponse) resp := new(reputation.AnnounceIntermediateResultResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.AnnounceIntermediateResult(ctx, req)) resp, err := util.WrapResponse(s.svc.AnnounceIntermediateResult(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }

View file

@ -24,8 +24,8 @@ func NewSignService(key *ecdsa.PrivateKey, svc Server) Server {
func (s *signService) Create(ctx context.Context, req *session.CreateRequest) (*session.CreateResponse, error) { func (s *signService) Create(ctx context.Context, req *session.CreateRequest) (*session.CreateResponse, error) {
if err := s.sigSvc.VerifyRequest(req); err != nil { if err := s.sigSvc.VerifyRequest(req); err != nil {
resp := new(session.CreateResponse) resp := new(session.CreateResponse)
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }
resp, err := util.WrapResponse(s.svc.Create(ctx, req)) resp, err := util.WrapResponse(s.svc.Create(ctx, req))
return resp, s.sigSvc.SignResponse(req, resp, err) return resp, s.sigSvc.SignResponse(util.IsStatusSupported(req), resp, err)
} }

View file

@ -1,7 +1,6 @@
package util package util
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"errors" "errors"
"fmt" "fmt"
@ -21,163 +20,23 @@ type ResponseMessage interface {
SetMetaHeader(*session.ResponseMetaHeader) SetMetaHeader(*session.ResponseMetaHeader)
} }
type UnaryHandler func(context.Context, any) (ResponseMessage, error)
type SignService struct { type SignService struct {
key *ecdsa.PrivateKey key *ecdsa.PrivateKey
} }
type ResponseMessageWriter func(ResponseMessage) error
type ServerStreamHandler func(context.Context, any) (ResponseMessageReader, error)
type ResponseMessageReader func() (ResponseMessage, error)
var ErrAbortStream = errors.New("abort message stream") var ErrAbortStream = errors.New("abort message stream")
type ResponseConstructor func() ResponseMessage
type RequestMessageWriter func(any) error
type ClientStreamCloser func() (ResponseMessage, error)
type RequestMessageStreamer struct {
key *ecdsa.PrivateKey
send RequestMessageWriter
close ClientStreamCloser
respCons ResponseConstructor
statusSupported bool
sendErr error
}
func NewUnarySignService(key *ecdsa.PrivateKey) *SignService { func NewUnarySignService(key *ecdsa.PrivateKey) *SignService {
return &SignService{ return &SignService{
key: key, key: key,
} }
} }
func (s *RequestMessageStreamer) Send(req any) error { // SignResponse response with private key via signature.SignServiceMessage.
// req argument should be strengthen with type RequestMessage // The signature error affects the result depending on the protocol version:
s.statusSupported = isStatusSupported(req.(RequestMessage)) // panic is OK here for now // - if status return is supported, panics since we cannot return the failed status, because it will not be signed.
// - otherwise, returns error in order to transport it directly.
var err error func (s *SignService) SignResponse(statusSupported bool, resp ResponseMessage, err error) error {
// verify request signatures
if err = signature.VerifyServiceMessage(req); err != nil {
err = fmt.Errorf("could not verify request: %w", err)
} else {
err = s.send(req)
}
if err != nil {
if !s.statusSupported {
return err
}
s.sendErr = err
return ErrAbortStream
}
return nil
}
func (s *RequestMessageStreamer) CloseAndRecv() (ResponseMessage, error) {
var (
resp ResponseMessage
err error
)
if s.sendErr != nil {
err = s.sendErr
} else {
resp, err = s.close()
if err != nil {
err = fmt.Errorf("could not close stream and receive response: %w", err)
}
}
if err != nil {
if !s.statusSupported {
return nil, err
}
resp = s.respCons()
setStatusV2(resp, err)
}
if err = signResponse(s.key, resp, s.statusSupported); err != nil {
return nil, err
}
return resp, nil
}
func (s *SignService) CreateRequestStreamer(sender RequestMessageWriter, closer ClientStreamCloser, blankResp ResponseConstructor) *RequestMessageStreamer {
return &RequestMessageStreamer{
key: s.key,
send: sender,
close: closer,
respCons: blankResp,
}
}
func (s *SignService) HandleServerStreamRequest(
req any,
respWriter ResponseMessageWriter,
blankResp ResponseConstructor,
respWriterCaller func(ResponseMessageWriter) error,
) error {
// handle protocol versions <=2.10 (API statuses was introduced in 2.11 only)
// req argument should be strengthen with type RequestMessage
statusSupported := isStatusSupported(req.(RequestMessage)) // panic is OK here for now
var err error
// verify request signatures
if err = signature.VerifyServiceMessage(req); err != nil {
err = fmt.Errorf("could not verify request: %w", err)
} else {
err = respWriterCaller(func(resp ResponseMessage) error {
if err := signResponse(s.key, resp, statusSupported); err != nil {
return err
}
return respWriter(resp)
})
}
if err != nil {
if !statusSupported {
return err
}
resp := blankResp()
setStatusV2(resp, err)
_ = signResponse(s.key, resp, false) // panics or returns nil with false arg
return respWriter(resp)
}
return nil
}
func (s *SignService) SignResponse(req RequestMessage, resp ResponseMessage, err error) error {
// handle protocol versions <=2.10 (API statuses was introduced in 2.11 only)
// req argument should be strengthen with type RequestMessage
statusSupported := isStatusSupported(req)
if err != nil { if err != nil {
if !statusSupported { if !statusSupported {
return err return err
@ -186,8 +45,18 @@ func (s *SignService) SignResponse(req RequestMessage, resp ResponseMessage, err
setStatusV2(resp, err) setStatusV2(resp, err)
} }
// sign the response err = signature.SignServiceMessage(s.key, resp)
return signResponse(s.key, resp, statusSupported) if err != nil {
err = fmt.Errorf("could not sign response: %w", err)
if statusSupported {
// We can't pass this error as status code since response will be unsigned.
// Isn't expected in practice, so panic is ok here.
panic(err)
}
}
return err
} }
func (s *SignService) VerifyRequest(req RequestMessage) error { func (s *SignService) VerifyRequest(req RequestMessage) error {
@ -207,7 +76,9 @@ func WrapResponse[T any](resp *T, err error) (*T, error) {
return new(T), nil return new(T), nil
} }
func isStatusSupported(req RequestMessage) bool { // IsStatusSupported returns true iff request version implies expecting status return.
// This allows us to handle protocol versions <=2.10 (API statuses was introduced in 2.11 only).
func IsStatusSupported(req RequestMessage) bool {
version := req.GetMetaHeader().GetVersion() version := req.GetMetaHeader().GetVersion()
mjr := version.GetMajor() mjr := version.GetMajor()
@ -223,22 +94,3 @@ func setStatusV2(resp ResponseMessage, err error) {
session.SetStatus(resp, apistatus.ToStatusV2(apistatus.ErrToStatus(err))) session.SetStatus(resp, apistatus.ToStatusV2(apistatus.ErrToStatus(err)))
} }
// signs response with private key via signature.SignServiceMessage.
// The signature error affects the result depending on the protocol version:
// - if status return is supported, panics since we cannot return the failed status, because it will not be signed;
// - otherwise, returns error in order to transport it directly.
func signResponse(key *ecdsa.PrivateKey, resp any, statusSupported bool) error {
err := signature.SignServiceMessage(key, resp)
if err != nil {
err = fmt.Errorf("could not sign response: %w", err)
if statusSupported {
// We can't pass this error as status code since response will be unsigned.
// Isn't expected in practice, so panic is ok here.
panic(err)
}
}
return err
}