v2/signature: Unify request and response signing

Signed-off-by: Leonard Lyubich <leonard@nspcc.ru>
This commit is contained in:
Leonard Lyubich 2020-08-18 16:01:40 +03:00 committed by Stanislav Bogatyrev
parent 82110751e7
commit db97b782c0

View file

@ -10,12 +10,18 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type SignedRequest interface { type serviceRequest interface {
GetMetaHeader() *service.RequestMetaHeader GetMetaHeader() *service.RequestMetaHeader
GetVerificationHeader() *service.RequestVerificationHeader GetVerificationHeader() *service.RequestVerificationHeader
SetVerificationHeader(*service.RequestVerificationHeader) SetVerificationHeader(*service.RequestVerificationHeader)
} }
type serviceResponse interface {
GetMetaHeader() *service.ResponseMetaHeader
GetVerificationHeader() *service.ResponseVerificationHeader
SetVerificationHeader(*service.ResponseVerificationHeader)
}
type stableMarshaler interface { type stableMarshaler interface {
StableMarshal([]byte) ([]byte, error) StableMarshal([]byte) ([]byte, error)
StableSize() int StableSize() int
@ -25,12 +31,99 @@ type stableMarshalerWrapper struct {
sm stableMarshaler sm stableMarshaler
} }
type metaHeader interface {
stableMarshaler
getOrigin() metaHeader
}
type verificationHeader interface {
stableMarshaler
GetBodySignature() *service.Signature
SetBodySignature(*service.Signature)
GetMetaSignature() *service.Signature
SetMetaSignature(*service.Signature)
GetOriginSignature() *service.Signature
SetOriginSignature(*service.Signature)
setOrigin(stableMarshaler)
getOrigin() verificationHeader
}
type requestMetaHeader struct {
*service.RequestMetaHeader
}
type responseMetaHeader struct {
*service.ResponseMetaHeader
}
type requestVerificationHeader struct {
*service.RequestVerificationHeader
}
type responseVerificationHeader struct {
*service.ResponseVerificationHeader
}
func (h *requestMetaHeader) getOrigin() metaHeader {
return &requestMetaHeader{
RequestMetaHeader: h.GetOrigin(),
}
}
func (h *responseMetaHeader) getOrigin() metaHeader {
return &responseMetaHeader{
ResponseMetaHeader: h.GetOrigin(),
}
}
func (h *requestVerificationHeader) getOrigin() verificationHeader {
if origin := h.GetOrigin(); origin != nil {
return &requestVerificationHeader{
RequestVerificationHeader: origin,
}
}
return nil
}
func (h *requestVerificationHeader) setOrigin(m stableMarshaler) {
if m != nil {
h.SetOrigin(m.(*service.RequestVerificationHeader))
}
}
func (r *responseVerificationHeader) getOrigin() verificationHeader {
if origin := r.GetOrigin(); origin != nil {
return &responseVerificationHeader{
ResponseVerificationHeader: origin,
}
}
return nil
}
func (r *responseVerificationHeader) setOrigin(m stableMarshaler) {
if m != nil {
r.SetOrigin(m.(*service.ResponseVerificationHeader))
}
}
func (s stableMarshalerWrapper) ReadSignedData(buf []byte) ([]byte, error) { func (s stableMarshalerWrapper) ReadSignedData(buf []byte) ([]byte, error) {
if s.sm != nil {
return s.sm.StableMarshal(buf) return s.sm.StableMarshal(buf)
}
return nil, nil
} }
func (s stableMarshalerWrapper) SignedDataSize() int { func (s stableMarshalerWrapper) SignedDataSize() int {
if s.sm != nil {
return s.sm.StableSize() return s.sm.StableSize()
}
return 0
} }
func keySignatureHandler(s *service.Signature) signature.KeySignatureHandler { func keySignatureHandler(s *service.Signature) signature.KeySignatureHandler {
@ -46,48 +139,69 @@ func keySignatureSource(s *service.Signature) signature.KeySignatureSource {
} }
} }
func requestBody(req SignedRequest) stableMarshaler { func SignServiceMessage(key *ecdsa.PrivateKey, msg interface{}) error {
switch v := req.(type) { var (
case *accounting.BalanceRequest: body, meta, verifyOrigin stableMarshaler
return v.GetBody() verifyHdr verificationHeader
default: verifyHdrSetter func(verificationHeader)
panic(fmt.Sprintf("unknown request %T", req)) )
}
}
func SignRequest(key *ecdsa.PrivateKey, req SignedRequest) error { switch v := msg.(type) {
if req == nil { case nil:
return nil return nil
case serviceRequest:
body = serviceMessageBody(v)
meta = v.GetMetaHeader()
verifyHdr = &requestVerificationHeader{new(service.RequestVerificationHeader)}
verifyHdrSetter = func(h verificationHeader) {
v.SetVerificationHeader(h.(*requestVerificationHeader).RequestVerificationHeader)
} }
// create new level of matryoshka if h := v.GetVerificationHeader(); h != nil {
verifyHdr := new(service.RequestVerificationHeader) verifyOrigin = h
}
case serviceResponse:
body = serviceMessageBody(v)
meta = v.GetMetaHeader()
verifyHdr = &responseVerificationHeader{new(service.ResponseVerificationHeader)}
verifyHdrSetter = func(h verificationHeader) {
v.SetVerificationHeader(h.(*responseVerificationHeader).ResponseVerificationHeader)
}
// attach the previous matryoshka if h := v.GetVerificationHeader(); h != nil {
verifyHdr.SetOrigin(req.GetVerificationHeader()) verifyOrigin = h
}
default:
panic(fmt.Sprintf("unsupported service message %T", v))
}
// sign request body if verifyOrigin == nil {
if err := signRequestPart(key, requestBody(req), verifyHdr.SetBodySignature); err != nil { // sign service message body
return errors.Wrap(err, "could not sign request body") if err := signServiceMessagePart(key, body, verifyHdr.SetBodySignature); err != nil {
return errors.Wrap(err, "could not sign body")
}
} }
// sign meta header // sign meta header
if err := signRequestPart(key, req.GetMetaHeader(), verifyHdr.SetMetaSignature); err != nil { if err := signServiceMessagePart(key, meta, verifyHdr.SetMetaSignature); err != nil {
return errors.Wrap(err, "could not sign request meta header") return errors.Wrap(err, "could not sign meta header")
} }
// sign verification header origin // sign verification header origin
if err := signRequestPart(key, verifyHdr.GetOrigin(), verifyHdr.SetOriginSignature); err != nil { if err := signServiceMessagePart(key, verifyOrigin, verifyHdr.SetOriginSignature); err != nil {
return errors.Wrap(err, "could not sign origin of request verification header") return errors.Wrap(err, "could not sign origin of verification header")
} }
// make a new top of the matryoshka // wrap origin verification header
req.SetVerificationHeader(verifyHdr) verifyHdr.setOrigin(verifyOrigin)
// update matryoshka verification header
verifyHdrSetter(verifyHdr)
return nil return nil
} }
func signRequestPart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*service.Signature)) error { func signServiceMessagePart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*service.Signature)) error {
sig := new(service.Signature) sig := new(service.Signature)
// sign part // sign part
@ -105,34 +219,78 @@ func signRequestPart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(
return nil return nil
} }
func VerifyRequest(req SignedRequest) error { func VerifyServiceMessage(msg interface{}) error {
verifyHdr := req.GetVerificationHeader() var (
meta metaHeader
verify verificationHeader
)
// verify body signature switch v := msg.(type) {
if err := verifyRequestPart(requestBody(req), verifyHdr.GetBodySignature); err != nil { case nil:
return errors.Wrap(err, "could not verify request body") return nil
case serviceRequest:
meta = &requestMetaHeader{
RequestMetaHeader: v.GetMetaHeader(),
} }
// verify meta header verify = &requestVerificationHeader{
if err := verifyRequestPart(req.GetMetaHeader(), verifyHdr.GetMetaSignature); err != nil { RequestVerificationHeader: v.GetVerificationHeader(),
return errors.Wrap(err, "could not verify request meta header") }
case serviceResponse:
meta = &responseMetaHeader{
ResponseMetaHeader: v.GetMetaHeader(),
} }
// verify verification header origin verify = &responseVerificationHeader{
if err := verifyRequestPart(verifyHdr.GetOrigin(), verifyHdr.GetOriginSignature); err != nil { ResponseVerificationHeader: v.GetVerificationHeader(),
return errors.Wrap(err, "could not verify origin of request verification header") }
default:
panic(fmt.Sprintf("unsupported service message %T", v))
}
return verifyMatryoshkaLevel(serviceMessageBody(msg), meta, verify)
}
func verifyMatryoshkaLevel(body stableMarshaler, meta metaHeader, verify verificationHeader) error {
if err := verifyServiceMessagePart(meta, verify.GetMetaSignature); err != nil {
return errors.Wrap(err, "could not verify meta header")
}
origin := verify.getOrigin()
if err := verifyServiceMessagePart(origin, verify.GetOriginSignature); err != nil {
return errors.Wrap(err, "could not verify origin of verification header")
}
if origin == nil {
if err := verifyServiceMessagePart(body, verify.GetBodySignature); err != nil {
return errors.Wrap(err, "could not verify body")
} }
return nil return nil
}
if verify.GetBodySignature() != nil {
return errors.New("body signature at the matryoshka upper level")
}
return verifyMatryoshkaLevel(body, meta.getOrigin(), origin)
} }
func verifyRequestPart(part stableMarshaler, sigRdr func() *service.Signature) error { func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *service.Signature) error {
if err := signature.VerifyDataWithSource( return signature.VerifyDataWithSource(
&stableMarshalerWrapper{part}, &stableMarshalerWrapper{part},
keySignatureSource(sigRdr()), keySignatureSource(sigRdr()),
); err != nil { )
return err }
}
func serviceMessageBody(req interface{}) stableMarshaler {
return nil switch v := req.(type) {
case *accounting.BalanceRequest:
return v.GetBody()
case *accounting.BalanceResponse:
return v.GetBody()
default:
panic(fmt.Sprintf("unsupported service message %T", req))
}
} }