[#6] services/util: Remove remaining stream wrappers
Signed-off-by: Evgenii Stratonikov <e.stratonikov@yadro.com>
This commit is contained in:
parent
372160d048
commit
c2617baf63
6 changed files with 139 additions and 283 deletions
|
@ -1,7 +1,6 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -21,163 +20,23 @@ type ResponseMessage interface {
|
|||
SetMetaHeader(*session.ResponseMetaHeader)
|
||||
}
|
||||
|
||||
type UnaryHandler func(context.Context, any) (ResponseMessage, error)
|
||||
|
||||
type SignService struct {
|
||||
key *ecdsa.PrivateKey
|
||||
}
|
||||
|
||||
type ResponseMessageWriter func(ResponseMessage) error
|
||||
|
||||
type ServerStreamHandler func(context.Context, any) (ResponseMessageReader, error)
|
||||
|
||||
type ResponseMessageReader func() (ResponseMessage, error)
|
||||
|
||||
var ErrAbortStream = errors.New("abort message stream")
|
||||
|
||||
type ResponseConstructor func() ResponseMessage
|
||||
|
||||
type RequestMessageWriter func(context.Context, any) error
|
||||
|
||||
type ClientStreamCloser func(context.Context) (ResponseMessage, error)
|
||||
|
||||
type RequestMessageStreamer struct {
|
||||
key *ecdsa.PrivateKey
|
||||
|
||||
send RequestMessageWriter
|
||||
|
||||
close ClientStreamCloser
|
||||
|
||||
respCons ResponseConstructor
|
||||
|
||||
statusSupported bool
|
||||
|
||||
sendErr error
|
||||
}
|
||||
|
||||
func NewUnarySignService(key *ecdsa.PrivateKey) *SignService {
|
||||
return &SignService{
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RequestMessageStreamer) Send(ctx context.Context, req any) error {
|
||||
// req argument should be strengthen with type RequestMessage
|
||||
s.statusSupported = isStatusSupported(req.(RequestMessage)) // panic is OK here for now
|
||||
|
||||
var err error
|
||||
|
||||
// verify request signatures
|
||||
if err = signature.VerifyServiceMessage(req); err != nil {
|
||||
err = fmt.Errorf("could not verify request: %w", err)
|
||||
} else {
|
||||
err = s.send(ctx, req)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !s.statusSupported {
|
||||
return err
|
||||
}
|
||||
|
||||
s.sendErr = err
|
||||
|
||||
return ErrAbortStream
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RequestMessageStreamer) CloseAndRecv(ctx context.Context) (ResponseMessage, error) {
|
||||
var (
|
||||
resp ResponseMessage
|
||||
err error
|
||||
)
|
||||
|
||||
if s.sendErr != nil {
|
||||
err = s.sendErr
|
||||
} else {
|
||||
resp, err = s.close(ctx)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("could not close stream and receive response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !s.statusSupported {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp = s.respCons()
|
||||
|
||||
setStatusV2(resp, err)
|
||||
}
|
||||
|
||||
if err = signResponse(s.key, resp, s.statusSupported); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *SignService) CreateRequestStreamer(sender RequestMessageWriter, closer ClientStreamCloser, blankResp ResponseConstructor) *RequestMessageStreamer {
|
||||
return &RequestMessageStreamer{
|
||||
key: s.key,
|
||||
send: sender,
|
||||
close: closer,
|
||||
|
||||
respCons: blankResp,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SignService) HandleServerStreamRequest(
|
||||
req any,
|
||||
respWriter ResponseMessageWriter,
|
||||
blankResp ResponseConstructor,
|
||||
respWriterCaller func(ResponseMessageWriter) error,
|
||||
) error {
|
||||
// handle protocol versions <=2.10 (API statuses was introduced in 2.11 only)
|
||||
|
||||
// req argument should be strengthen with type RequestMessage
|
||||
statusSupported := isStatusSupported(req.(RequestMessage)) // panic is OK here for now
|
||||
|
||||
var err error
|
||||
|
||||
// verify request signatures
|
||||
if err = signature.VerifyServiceMessage(req); err != nil {
|
||||
err = fmt.Errorf("could not verify request: %w", err)
|
||||
} else {
|
||||
err = respWriterCaller(func(resp ResponseMessage) error {
|
||||
if err := signResponse(s.key, resp, statusSupported); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return respWriter(resp)
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !statusSupported {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := blankResp()
|
||||
|
||||
setStatusV2(resp, err)
|
||||
|
||||
_ = signResponse(s.key, resp, false) // panics or returns nil with false arg
|
||||
|
||||
return respWriter(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SignService) SignResponse(req RequestMessage, resp ResponseMessage, err error) error {
|
||||
// handle protocol versions <=2.10 (API statuses was introduced in 2.11 only)
|
||||
|
||||
// req argument should be strengthen with type RequestMessage
|
||||
statusSupported := isStatusSupported(req)
|
||||
|
||||
// SignResponse response with private key via signature.SignServiceMessage.
|
||||
// The signature error affects the result depending on the protocol version:
|
||||
// - if status return is supported, panics since we cannot return the failed status, because it will not be signed.
|
||||
// - otherwise, returns error in order to transport it directly.
|
||||
func (s *SignService) SignResponse(statusSupported bool, resp ResponseMessage, err error) error {
|
||||
if err != nil {
|
||||
if !statusSupported {
|
||||
return err
|
||||
|
@ -186,8 +45,18 @@ func (s *SignService) SignResponse(req RequestMessage, resp ResponseMessage, err
|
|||
setStatusV2(resp, err)
|
||||
}
|
||||
|
||||
// sign the response
|
||||
return signResponse(s.key, resp, statusSupported)
|
||||
err = signature.SignServiceMessage(s.key, resp)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("could not sign response: %w", err)
|
||||
|
||||
if statusSupported {
|
||||
// We can't pass this error as status code since response will be unsigned.
|
||||
// Isn't expected in practice, so panic is ok here.
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SignService) VerifyRequest(req RequestMessage) error {
|
||||
|
@ -207,7 +76,9 @@ func EnsureNonNilResponse[T any](resp *T, err error) (*T, error) {
|
|||
return new(T), err
|
||||
}
|
||||
|
||||
func isStatusSupported(req RequestMessage) bool {
|
||||
// IsStatusSupported returns true iff request version implies expecting status return.
|
||||
// This allows us to handle protocol versions <=2.10 (API statuses was introduced in 2.11 only).
|
||||
func IsStatusSupported(req RequestMessage) bool {
|
||||
version := req.GetMetaHeader().GetVersion()
|
||||
|
||||
mjr := version.GetMajor()
|
||||
|
@ -223,22 +94,3 @@ func setStatusV2(resp ResponseMessage, err error) {
|
|||
|
||||
session.SetStatus(resp, apistatus.ToStatusV2(apistatus.ErrToStatus(err)))
|
||||
}
|
||||
|
||||
// signs response with private key via signature.SignServiceMessage.
|
||||
// The signature error affects the result depending on the protocol version:
|
||||
// - if status return is supported, panics since we cannot return the failed status, because it will not be signed;
|
||||
// - otherwise, returns error in order to transport it directly.
|
||||
func signResponse(key *ecdsa.PrivateKey, resp any, statusSupported bool) error {
|
||||
err := signature.SignServiceMessage(key, resp)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("could not sign response: %w", err)
|
||||
|
||||
if statusSupported {
|
||||
// We can't pass this error as status code since response will be unsigned.
|
||||
// Isn't expected in practice, so panic is ok here.
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue