From 3f7d3f8a86d5df90e222e64f16b735b81766407c Mon Sep 17 00:00:00 2001 From: Leonard Lyubich Date: Thu, 18 Jun 2020 15:24:17 +0300 Subject: [PATCH] service: make RequestData to provide BearerTokenSource interface --- service/bearer.go | 30 ++++++++++++++++++++++++++++++ service/bearer_test.go | 15 +++++++++++++++ service/sign_test.go | 6 ++++++ service/types.go | 1 + service/verify.go | 11 +++++++++++ service/verify_test.go | 10 ++++++++++ 6 files changed, 73 insertions(+) diff --git a/service/bearer.go b/service/bearer.go index 6013e03..dc556ce 100644 --- a/service/bearer.go +++ b/service/bearer.go @@ -12,6 +12,10 @@ type signedBearerToken struct { BearerToken } +type bearerMsgWrapper struct { + *BearerTokenMsg +} + const fixedBearerTokenDataSize = 0 + refs.OwnerIDSize + 8 @@ -124,3 +128,29 @@ func (m *BearerTokenMsg) SetOwnerKey(v []byte) { func (m *BearerTokenMsg) SetSignature(v []byte) { m.Signature = v } + +func wrapBearerTokenMsg(msg *BearerTokenMsg) bearerMsgWrapper { + return bearerMsgWrapper{ + BearerTokenMsg: msg, + } +} + +// ExpirationEpoch returns the result of ValidUntil field getter. +// +// If message is nil, 0 returns. +func (s bearerMsgWrapper) ExpirationEpoch() uint64 { + if s.BearerTokenMsg != nil { + return s.GetValidUntil() + } + + return 0 +} + +// SetExpirationEpoch passes argument to ValidUntil field setter. +// +// If message is nil, nothing changes. +func (s bearerMsgWrapper) SetExpirationEpoch(v uint64) { + if s.BearerTokenMsg != nil { + s.SetValidUntil(v) + } +} diff --git a/service/bearer_test.go b/service/bearer_test.go index 9ece9c8..381f190 100644 --- a/service/bearer_test.go +++ b/service/bearer_test.go @@ -194,3 +194,18 @@ func TestBearerTokenMsg_Setters(t *testing.T) { s.SetSignature(sig) require.Equal(t, sig, s.GetSignature()) } + +func TestBearerMsgWrapper_ExpirationEpoch(t *testing.T) { + s := wrapBearerTokenMsg(nil) + require.Zero(t, s.ExpirationEpoch()) + require.NotPanics(t, func() { + s.SetExpirationEpoch(1) + }) + + msg := new(BearerTokenMsg) + s = wrapBearerTokenMsg(msg) + + epoch := uint64(7) + s.SetExpirationEpoch(epoch) + require.Equal(t, epoch, s.ExpirationEpoch()) +} diff --git a/service/sign_test.go b/service/sign_test.go index ca469b8..8b67e5b 100644 --- a/service/sign_test.go +++ b/service/sign_test.go @@ -18,6 +18,8 @@ type testSignedDataSrc struct { sig []byte key *ecdsa.PublicKey token SessionToken + + bearer BearerToken } type testSignedDataReader struct { @@ -54,6 +56,10 @@ func (s testSignedDataSrc) GetSessionToken() SessionToken { return s.token } +func (s testSignedDataSrc) GetBearerToken() BearerToken { + return s.bearer +} + func (s testSignedDataReader) SignedDataSize() int { return len(s.data) } diff --git a/service/types.go b/service/types.go index 87c3a77..feba2e3 100644 --- a/service/types.go +++ b/service/types.go @@ -254,6 +254,7 @@ type DataWithSignKeySource interface { type RequestData interface { SignedDataSource SessionTokenSource + BearerTokenSource } // RequestSignedData is an interface of request information with signature write access. diff --git a/service/verify.go b/service/verify.go index 0673a01..9fbdfdf 100644 --- a/service/verify.go +++ b/service/verify.go @@ -103,3 +103,14 @@ func (t testCustomField) MarshalTo(data []byte) (int, error) { return 0, nil } // Marshal skip, it's for test usage only. func (t testCustomField) Marshal() ([]byte, error) { return nil, nil } + +// GetBearerToken returns wraps Bearer field and return BearerToken interface. +// +// If Bearer field value is nil, nil returns. +func (m RequestVerificationHeader) GetBearerToken() BearerToken { + if t := m.GetBearer(); t != nil { + return wrapBearerTokenMsg(t) + } + + return nil +} diff --git a/service/verify_test.go b/service/verify_test.go index 55ec65f..b42bb79 100644 --- a/service/verify_test.go +++ b/service/verify_test.go @@ -128,3 +128,13 @@ func TestRequestVerificationHeader_SetBearer(t *testing.T) { require.Equal(t, token, h.GetBearer()) } + +func TestRequestVerificationHeader_GetBearerToken(t *testing.T) { + s := new(RequestVerificationHeader) + + require.Nil(t, s.GetBearerToken()) + + bearer := new(BearerTokenMsg) + s.SetBearer(bearer) + require.Equal(t, wrapBearerTokenMsg(bearer), s.GetBearerToken()) +}