diff --git a/session/errors.go b/session/errors.go index 3a9c129..d35bed4 100644 --- a/session/errors.go +++ b/session/errors.go @@ -13,3 +13,7 @@ const ErrNilGPRCClientConn = internal.Error("gRPC client connection is nil") // ErrPrivateTokenNotFound is returned when addressed private token was // not found in storage. const ErrPrivateTokenNotFound = internal.Error("private token not found") + +// ErrNilPrivateToken is returned by functions that expect a non-nil +// PrivateToken, but received nil. +const ErrNilPrivateToken = internal.Error("private token is nil") diff --git a/session/private.go b/session/private.go index 42bb205..bb9242f 100644 --- a/session/private.go +++ b/session/private.go @@ -30,14 +30,26 @@ func NewPrivateToken(validUntil uint64) (PrivateToken, error) { }, nil } -// Sign signs data with session private key. -func (t *pToken) Sign(data []byte) ([]byte, error) { - return crypto.Sign(t.sessionKey, data) +// PublicSessionToken returns a binary representation of session public key. +// +// If passed PrivateToken is nil, ErrNilPrivateToken returns. +// If passed PrivateToken carries nil private key, crypto.ErrEmptyPrivateKey returns. +func PublicSessionToken(pToken PrivateToken) ([]byte, error) { + if pToken == nil { + return nil, ErrNilPrivateToken + } + + sk := pToken.PrivateKey() + if sk == nil { + return nil, crypto.ErrEmptyPrivateKey + } + + return crypto.MarshalPublicKey(&sk.PublicKey), nil } -// PublicKey returns a binary representation of the session public key. -func (t *pToken) PublicKey() []byte { - return crypto.MarshalPublicKey(&t.sessionKey.PublicKey) +// PrivateKey is a session private key getter. +func (t *pToken) PrivateKey() *ecdsa.PrivateKey { + return t.sessionKey } func (t *pToken) Expired(epoch uint64) bool { diff --git a/session/private_test.go b/session/private_test.go index 8097b97..c6f8125 100644 --- a/session/private_test.go +++ b/session/private_test.go @@ -1,35 +1,17 @@ package session import ( - "crypto/rand" "testing" crypto "github.com/nspcc-dev/neofs-crypto" "github.com/stretchr/testify/require" ) -func TestPrivateToken(t *testing.T) { +func TestPToken_PrivateKey(t *testing.T) { // create new private token pToken, err := NewPrivateToken(0) require.NoError(t, err) - - // generate data to sign - data := make([]byte, 10) - _, err = rand.Read(data) - require.NoError(t, err) - - // sign data via private token - sig, err := pToken.Sign(data) - require.NoError(t, err) - - // check signature - require.NoError(t, - crypto.Verify( - crypto.UnmarshalPublicKey(pToken.PublicKey()), - data, - sig, - ), - ) + require.NotNil(t, pToken.PrivateKey()) } func TestPToken_Expired(t *testing.T) { @@ -68,3 +50,27 @@ func TestPrivateTokenKey_SetTokenID(t *testing.T) { require.Equal(t, tokenID, s.token) } + +func TestPublicSessionToken(t *testing.T) { + var err error + + // nil PrivateToken + _, err = PublicSessionToken(nil) + require.EqualError(t, err, ErrNilPrivateToken.Error()) + + // empty private key + var pToken PrivateToken = new(pToken) + _, err = PublicSessionToken(pToken) + require.EqualError(t, err, crypto.ErrEmptyPrivateKey.Error()) + + // correct PrivateToken + pToken, err = NewPrivateToken(0) + require.NoError(t, err) + + key := pToken.PrivateKey() + require.NotNil(t, key) + + res, err := PublicSessionToken(pToken) + require.NoError(t, err) + require.Equal(t, res, crypto.MarshalPublicKey(&key.PublicKey)) +} diff --git a/session/types.go b/session/types.go index ee13b92..95a0065 100644 --- a/session/types.go +++ b/session/types.go @@ -10,14 +10,8 @@ import ( // PrivateToken is an interface of session private part. type PrivateToken interface { - // PublicKey must return a binary representation of session public key. - PublicKey() []byte - - // Sign must return the signature of passed data. - // - // Resulting signature must be verified by crypto.Verify function - // with the session public key. - Sign([]byte) ([]byte, error) + // PrivateKey must return session private key. + PrivateKey() *ecdsa.PrivateKey // Expired must return true if and only if private token is expired in the given epoch number. Expired(uint64) bool