diff --git a/v2/signature/sign.go b/v2/signature/sign.go index d0fddc56..e17e90ee 100644 --- a/v2/signature/sign.go +++ b/v2/signature/sign.go @@ -10,12 +10,18 @@ import ( "github.com/pkg/errors" ) -type SignedRequest interface { +type serviceRequest interface { GetMetaHeader() *service.RequestMetaHeader GetVerificationHeader() *service.RequestVerificationHeader SetVerificationHeader(*service.RequestVerificationHeader) } +type serviceResponse interface { + GetMetaHeader() *service.ResponseMetaHeader + GetVerificationHeader() *service.ResponseVerificationHeader + SetVerificationHeader(*service.ResponseVerificationHeader) +} + type stableMarshaler interface { StableMarshal([]byte) ([]byte, error) StableSize() int @@ -25,12 +31,99 @@ type stableMarshalerWrapper struct { sm stableMarshaler } +type metaHeader interface { + stableMarshaler + getOrigin() metaHeader +} + +type verificationHeader interface { + stableMarshaler + + GetBodySignature() *service.Signature + SetBodySignature(*service.Signature) + GetMetaSignature() *service.Signature + SetMetaSignature(*service.Signature) + GetOriginSignature() *service.Signature + SetOriginSignature(*service.Signature) + + setOrigin(stableMarshaler) + getOrigin() verificationHeader +} + +type requestMetaHeader struct { + *service.RequestMetaHeader +} + +type responseMetaHeader struct { + *service.ResponseMetaHeader +} + +type requestVerificationHeader struct { + *service.RequestVerificationHeader +} + +type responseVerificationHeader struct { + *service.ResponseVerificationHeader +} + +func (h *requestMetaHeader) getOrigin() metaHeader { + return &requestMetaHeader{ + RequestMetaHeader: h.GetOrigin(), + } +} + +func (h *responseMetaHeader) getOrigin() metaHeader { + return &responseMetaHeader{ + ResponseMetaHeader: h.GetOrigin(), + } +} + +func (h *requestVerificationHeader) getOrigin() verificationHeader { + if origin := h.GetOrigin(); origin != nil { + return &requestVerificationHeader{ + RequestVerificationHeader: origin, + } + } + + return nil +} + +func (h *requestVerificationHeader) setOrigin(m stableMarshaler) { + if m != nil { + h.SetOrigin(m.(*service.RequestVerificationHeader)) + } +} + +func (r *responseVerificationHeader) getOrigin() verificationHeader { + if origin := r.GetOrigin(); origin != nil { + return &responseVerificationHeader{ + ResponseVerificationHeader: origin, + } + } + + return nil +} + +func (r *responseVerificationHeader) setOrigin(m stableMarshaler) { + if m != nil { + r.SetOrigin(m.(*service.ResponseVerificationHeader)) + } +} + func (s stableMarshalerWrapper) ReadSignedData(buf []byte) ([]byte, error) { - return s.sm.StableMarshal(buf) + if s.sm != nil { + return s.sm.StableMarshal(buf) + } + + return nil, nil } func (s stableMarshalerWrapper) SignedDataSize() int { - return s.sm.StableSize() + if s.sm != nil { + return s.sm.StableSize() + } + + return 0 } func keySignatureHandler(s *service.Signature) signature.KeySignatureHandler { @@ -46,48 +139,69 @@ func keySignatureSource(s *service.Signature) signature.KeySignatureSource { } } -func requestBody(req SignedRequest) stableMarshaler { - switch v := req.(type) { - case *accounting.BalanceRequest: - return v.GetBody() - default: - panic(fmt.Sprintf("unknown request %T", req)) - } -} +func SignServiceMessage(key *ecdsa.PrivateKey, msg interface{}) error { + var ( + body, meta, verifyOrigin stableMarshaler + verifyHdr verificationHeader + verifyHdrSetter func(verificationHeader) + ) -func SignRequest(key *ecdsa.PrivateKey, req SignedRequest) error { - if req == nil { + switch v := msg.(type) { + case nil: return nil + case serviceRequest: + body = serviceMessageBody(v) + meta = v.GetMetaHeader() + verifyHdr = &requestVerificationHeader{new(service.RequestVerificationHeader)} + verifyHdrSetter = func(h verificationHeader) { + v.SetVerificationHeader(h.(*requestVerificationHeader).RequestVerificationHeader) + } + + if h := v.GetVerificationHeader(); h != nil { + verifyOrigin = h + } + case serviceResponse: + body = serviceMessageBody(v) + meta = v.GetMetaHeader() + verifyHdr = &responseVerificationHeader{new(service.ResponseVerificationHeader)} + verifyHdrSetter = func(h verificationHeader) { + v.SetVerificationHeader(h.(*responseVerificationHeader).ResponseVerificationHeader) + } + + if h := v.GetVerificationHeader(); h != nil { + verifyOrigin = h + } + default: + panic(fmt.Sprintf("unsupported service message %T", v)) } - // create new level of matryoshka - verifyHdr := new(service.RequestVerificationHeader) - - // attach the previous matryoshka - verifyHdr.SetOrigin(req.GetVerificationHeader()) - - // sign request body - if err := signRequestPart(key, requestBody(req), verifyHdr.SetBodySignature); err != nil { - return errors.Wrap(err, "could not sign request body") + if verifyOrigin == nil { + // sign service message body + if err := signServiceMessagePart(key, body, verifyHdr.SetBodySignature); err != nil { + return errors.Wrap(err, "could not sign body") + } } // sign meta header - if err := signRequestPart(key, req.GetMetaHeader(), verifyHdr.SetMetaSignature); err != nil { - return errors.Wrap(err, "could not sign request meta header") + if err := signServiceMessagePart(key, meta, verifyHdr.SetMetaSignature); err != nil { + return errors.Wrap(err, "could not sign meta header") } // sign verification header origin - if err := signRequestPart(key, verifyHdr.GetOrigin(), verifyHdr.SetOriginSignature); err != nil { - return errors.Wrap(err, "could not sign origin of request verification header") + if err := signServiceMessagePart(key, verifyOrigin, verifyHdr.SetOriginSignature); err != nil { + return errors.Wrap(err, "could not sign origin of verification header") } - // make a new top of the matryoshka - req.SetVerificationHeader(verifyHdr) + // wrap origin verification header + verifyHdr.setOrigin(verifyOrigin) + + // update matryoshka verification header + verifyHdrSetter(verifyHdr) return nil } -func signRequestPart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*service.Signature)) error { +func signServiceMessagePart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*service.Signature)) error { sig := new(service.Signature) // sign part @@ -105,34 +219,78 @@ func signRequestPart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func( return nil } -func VerifyRequest(req SignedRequest) error { - verifyHdr := req.GetVerificationHeader() +func VerifyServiceMessage(msg interface{}) error { + var ( + meta metaHeader + verify verificationHeader + ) - // verify body signature - if err := verifyRequestPart(requestBody(req), verifyHdr.GetBodySignature); err != nil { - return errors.Wrap(err, "could not verify request body") + switch v := msg.(type) { + case nil: + return nil + case serviceRequest: + meta = &requestMetaHeader{ + RequestMetaHeader: v.GetMetaHeader(), + } + + verify = &requestVerificationHeader{ + RequestVerificationHeader: v.GetVerificationHeader(), + } + case serviceResponse: + meta = &responseMetaHeader{ + ResponseMetaHeader: v.GetMetaHeader(), + } + + verify = &responseVerificationHeader{ + ResponseVerificationHeader: v.GetVerificationHeader(), + } + default: + panic(fmt.Sprintf("unsupported service message %T", v)) } - // verify meta header - if err := verifyRequestPart(req.GetMetaHeader(), verifyHdr.GetMetaSignature); err != nil { - return errors.Wrap(err, "could not verify request meta header") - } - - // verify verification header origin - if err := verifyRequestPart(verifyHdr.GetOrigin(), verifyHdr.GetOriginSignature); err != nil { - return errors.Wrap(err, "could not verify origin of request verification header") - } - - return nil + return verifyMatryoshkaLevel(serviceMessageBody(msg), meta, verify) } -func verifyRequestPart(part stableMarshaler, sigRdr func() *service.Signature) error { - if err := signature.VerifyDataWithSource( +func verifyMatryoshkaLevel(body stableMarshaler, meta metaHeader, verify verificationHeader) error { + if err := verifyServiceMessagePart(meta, verify.GetMetaSignature); err != nil { + return errors.Wrap(err, "could not verify meta header") + } + + origin := verify.getOrigin() + + if err := verifyServiceMessagePart(origin, verify.GetOriginSignature); err != nil { + return errors.Wrap(err, "could not verify origin of verification header") + } + + if origin == nil { + if err := verifyServiceMessagePart(body, verify.GetBodySignature); err != nil { + return errors.Wrap(err, "could not verify body") + } + + return nil + } + + if verify.GetBodySignature() != nil { + return errors.New("body signature at the matryoshka upper level") + } + + return verifyMatryoshkaLevel(body, meta.getOrigin(), origin) +} + +func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *service.Signature) error { + return signature.VerifyDataWithSource( &stableMarshalerWrapper{part}, keySignatureSource(sigRdr()), - ); err != nil { - return err - } - - return nil + ) +} + +func serviceMessageBody(req interface{}) stableMarshaler { + switch v := req.(type) { + case *accounting.BalanceRequest: + return v.GetBody() + case *accounting.BalanceResponse: + return v.GetBody() + default: + panic(fmt.Sprintf("unsupported service message %T", req)) + } }