From e08e3d6c00e96dc6b7964e68790c247c1aac2217 Mon Sep 17 00:00:00 2001
From: Denis Kirillov <denis@nspcc.ru>
Date: Mon, 25 Oct 2021 16:24:43 +0300
Subject: [PATCH] [#38] Add session cache tests

Signed-off-by: Denis Kirillov <denis@nspcc.ru>
---
 pool/pool.go      |  62 +++++++++++++--------------
 pool/pool_test.go | 107 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 138 insertions(+), 31 deletions(-)

diff --git a/pool/pool.go b/pool/pool.go
index dff8711a..451fd940 100644
--- a/pool/pool.go
+++ b/pool/pool.go
@@ -93,21 +93,21 @@ type Pool interface {
 	WaitForContainerPresence(context.Context, *cid.ID, *ContainerPollingParams) error
 	Close()
 
-	PutObjectParam(ctx context.Context, params *client.PutObjectParams, callParam CallParam) (*object.ID, error)
-	DeleteObjectParam(ctx context.Context, params *client.DeleteObjectParams, callParam CallParam) error
-	GetObjectParam(ctx context.Context, params *client.GetObjectParams, callParam CallParam) (*object.Object, error)
-	GetObjectHeaderParam(ctx context.Context, params *client.ObjectHeaderParams, callParam CallParam) (*object.Object, error)
-	ObjectPayloadRangeDataParam(ctx context.Context, params *client.RangeDataParams, callParam CallParam) ([]byte, error)
-	ObjectPayloadRangeSHA256Param(ctx context.Context, params *client.RangeChecksumParams, callParam CallParam) ([][32]byte, error)
-	ObjectPayloadRangeTZParam(ctx context.Context, params *client.RangeChecksumParams, callParam CallParam) ([][64]byte, error)
-	SearchObjectParam(ctx context.Context, params *client.SearchObjectParams, callParam CallParam) ([]*object.ID, error)
-	PutContainerParam(ctx context.Context, cnr *container.Container, callParam CallParam) (*cid.ID, error)
-	GetContainerParam(ctx context.Context, cid *cid.ID, callParam CallParam) (*container.Container, error)
-	ListContainersParam(ctx context.Context, ownerID *owner.ID, callParam CallParam) ([]*cid.ID, error)
-	DeleteContainerParam(ctx context.Context, cid *cid.ID, callParam CallParam) error
-	GetEACLParam(ctx context.Context, cid *cid.ID, callParam CallParam) (*client.EACLWithSignature, error)
-	SetEACLParam(ctx context.Context, table *eacl.Table, callParam CallParam) error
-	AnnounceContainerUsedSpaceParam(ctx context.Context, announce []container.UsedSpaceAnnouncement, callParam CallParam) error
+	PutObjectParam(ctx context.Context, params *client.PutObjectParams, callParam *CallParam) (*object.ID, error)
+	DeleteObjectParam(ctx context.Context, params *client.DeleteObjectParams, callParam *CallParam) error
+	GetObjectParam(ctx context.Context, params *client.GetObjectParams, callParam *CallParam) (*object.Object, error)
+	GetObjectHeaderParam(ctx context.Context, params *client.ObjectHeaderParams, callParam *CallParam) (*object.Object, error)
+	ObjectPayloadRangeDataParam(ctx context.Context, params *client.RangeDataParams, callParam *CallParam) ([]byte, error)
+	ObjectPayloadRangeSHA256Param(ctx context.Context, params *client.RangeChecksumParams, callParam *CallParam) ([][32]byte, error)
+	ObjectPayloadRangeTZParam(ctx context.Context, params *client.RangeChecksumParams, callParam *CallParam) ([][64]byte, error)
+	SearchObjectParam(ctx context.Context, params *client.SearchObjectParams, callParam *CallParam) ([]*object.ID, error)
+	PutContainerParam(ctx context.Context, cnr *container.Container, callParam *CallParam) (*cid.ID, error)
+	GetContainerParam(ctx context.Context, cid *cid.ID, callParam *CallParam) (*container.Container, error)
+	ListContainersParam(ctx context.Context, ownerID *owner.ID, callParam *CallParam) ([]*cid.ID, error)
+	DeleteContainerParam(ctx context.Context, cid *cid.ID, callParam *CallParam) error
+	GetEACLParam(ctx context.Context, cid *cid.ID, callParam *CallParam) (*client.EACLWithSignature, error)
+	SetEACLParam(ctx context.Context, table *eacl.Table, callParam *CallParam) error
+	AnnounceContainerUsedSpaceParam(ctx context.Context, announce []container.UsedSpaceAnnouncement, callParam *CallParam) error
 }
 
 type clientPack struct {
@@ -319,7 +319,7 @@ func formCacheKey(address string, key *ecdsa.PrivateKey) string {
 	return address + k.String()
 }
 
-func (p *pool) connParam(ctx context.Context, param CallParam) (*clientPack, []client.CallOption, error) {
+func (p *pool) connParam(ctx context.Context, param *CallParam) (*clientPack, []client.CallOption, error) {
 	cp, err := p.connection()
 	if err != nil {
 		return nil, nil, err
@@ -474,7 +474,7 @@ func (p *pool) checkSessionTokenErr(err error, address string) {
 	}
 }
 
-func (p *pool) PutObjectParam(ctx context.Context, params *client.PutObjectParams, callParam CallParam) (*object.ID, error) {
+func (p *pool) PutObjectParam(ctx context.Context, params *client.PutObjectParams, callParam *CallParam) (*object.ID, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -484,7 +484,7 @@ func (p *pool) PutObjectParam(ctx context.Context, params *client.PutObjectParam
 	return res, err
 }
 
-func (p *pool) DeleteObjectParam(ctx context.Context, params *client.DeleteObjectParams, callParam CallParam) error {
+func (p *pool) DeleteObjectParam(ctx context.Context, params *client.DeleteObjectParams, callParam *CallParam) error {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return err
@@ -494,7 +494,7 @@ func (p *pool) DeleteObjectParam(ctx context.Context, params *client.DeleteObjec
 	return err
 }
 
-func (p *pool) GetObjectParam(ctx context.Context, params *client.GetObjectParams, callParam CallParam) (*object.Object, error) {
+func (p *pool) GetObjectParam(ctx context.Context, params *client.GetObjectParams, callParam *CallParam) (*object.Object, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -504,7 +504,7 @@ func (p *pool) GetObjectParam(ctx context.Context, params *client.GetObjectParam
 	return res, err
 }
 
-func (p *pool) GetObjectHeaderParam(ctx context.Context, params *client.ObjectHeaderParams, callParam CallParam) (*object.Object, error) {
+func (p *pool) GetObjectHeaderParam(ctx context.Context, params *client.ObjectHeaderParams, callParam *CallParam) (*object.Object, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -514,7 +514,7 @@ func (p *pool) GetObjectHeaderParam(ctx context.Context, params *client.ObjectHe
 	return res, err
 }
 
-func (p *pool) ObjectPayloadRangeDataParam(ctx context.Context, params *client.RangeDataParams, callParam CallParam) ([]byte, error) {
+func (p *pool) ObjectPayloadRangeDataParam(ctx context.Context, params *client.RangeDataParams, callParam *CallParam) ([]byte, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -524,7 +524,7 @@ func (p *pool) ObjectPayloadRangeDataParam(ctx context.Context, params *client.R
 	return res, err
 }
 
-func (p *pool) ObjectPayloadRangeSHA256Param(ctx context.Context, params *client.RangeChecksumParams, callParam CallParam) ([][32]byte, error) {
+func (p *pool) ObjectPayloadRangeSHA256Param(ctx context.Context, params *client.RangeChecksumParams, callParam *CallParam) ([][32]byte, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -534,7 +534,7 @@ func (p *pool) ObjectPayloadRangeSHA256Param(ctx context.Context, params *client
 	return res, err
 }
 
-func (p *pool) ObjectPayloadRangeTZParam(ctx context.Context, params *client.RangeChecksumParams, callParam CallParam) ([][64]byte, error) {
+func (p *pool) ObjectPayloadRangeTZParam(ctx context.Context, params *client.RangeChecksumParams, callParam *CallParam) ([][64]byte, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -544,7 +544,7 @@ func (p *pool) ObjectPayloadRangeTZParam(ctx context.Context, params *client.Ran
 	return res, err
 }
 
-func (p *pool) SearchObjectParam(ctx context.Context, params *client.SearchObjectParams, callParam CallParam) ([]*object.ID, error) {
+func (p *pool) SearchObjectParam(ctx context.Context, params *client.SearchObjectParams, callParam *CallParam) ([]*object.ID, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -554,7 +554,7 @@ func (p *pool) SearchObjectParam(ctx context.Context, params *client.SearchObjec
 	return res, err
 }
 
-func (p *pool) PutContainerParam(ctx context.Context, cnr *container.Container, callParam CallParam) (*cid.ID, error) {
+func (p *pool) PutContainerParam(ctx context.Context, cnr *container.Container, callParam *CallParam) (*cid.ID, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -564,7 +564,7 @@ func (p *pool) PutContainerParam(ctx context.Context, cnr *container.Container,
 	return res, err
 }
 
-func (p *pool) GetContainerParam(ctx context.Context, cid *cid.ID, callParam CallParam) (*container.Container, error) {
+func (p *pool) GetContainerParam(ctx context.Context, cid *cid.ID, callParam *CallParam) (*container.Container, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -574,7 +574,7 @@ func (p *pool) GetContainerParam(ctx context.Context, cid *cid.ID, callParam Cal
 	return res, err
 }
 
-func (p *pool) ListContainersParam(ctx context.Context, ownerID *owner.ID, callParam CallParam) ([]*cid.ID, error) {
+func (p *pool) ListContainersParam(ctx context.Context, ownerID *owner.ID, callParam *CallParam) ([]*cid.ID, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -584,7 +584,7 @@ func (p *pool) ListContainersParam(ctx context.Context, ownerID *owner.ID, callP
 	return res, err
 }
 
-func (p *pool) DeleteContainerParam(ctx context.Context, cid *cid.ID, callParam CallParam) error {
+func (p *pool) DeleteContainerParam(ctx context.Context, cid *cid.ID, callParam *CallParam) error {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return err
@@ -594,7 +594,7 @@ func (p *pool) DeleteContainerParam(ctx context.Context, cid *cid.ID, callParam
 	return err
 }
 
-func (p *pool) GetEACLParam(ctx context.Context, cid *cid.ID, callParam CallParam) (*client.EACLWithSignature, error) {
+func (p *pool) GetEACLParam(ctx context.Context, cid *cid.ID, callParam *CallParam) (*client.EACLWithSignature, error) {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return nil, err
@@ -604,7 +604,7 @@ func (p *pool) GetEACLParam(ctx context.Context, cid *cid.ID, callParam CallPara
 	return res, err
 }
 
-func (p *pool) SetEACLParam(ctx context.Context, table *eacl.Table, callParam CallParam) error {
+func (p *pool) SetEACLParam(ctx context.Context, table *eacl.Table, callParam *CallParam) error {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return err
@@ -614,7 +614,7 @@ func (p *pool) SetEACLParam(ctx context.Context, table *eacl.Table, callParam Ca
 	return err
 }
 
-func (p *pool) AnnounceContainerUsedSpaceParam(ctx context.Context, announce []container.UsedSpaceAnnouncement, callParam CallParam) error {
+func (p *pool) AnnounceContainerUsedSpaceParam(ctx context.Context, announce []container.UsedSpaceAnnouncement, callParam *CallParam) error {
 	cp, options, err := p.connParam(ctx, callParam)
 	if err != nil {
 		return err
diff --git a/pool/pool_test.go b/pool/pool_test.go
index f9ff43f1..f2d2155e 100644
--- a/pool/pool_test.go
+++ b/pool/pool_test.go
@@ -308,7 +308,114 @@ func TestTwoFailed(t *testing.T) {
 }
 
 func TestSessionCache(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
 
+	var tokens []*session.Token
+	clientBuilder := func(opts ...client.Option) (client.Client, error) {
+		mockClient := NewMockClient(ctrl)
+		mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
+			tok := session.NewToken()
+			uid, err := uuid.New().MarshalBinary()
+			require.NoError(t, err)
+			tok.SetID(uid)
+			tokens = append(tokens, tok)
+			return tok, err
+		}).MaxTimes(2)
+
+		mockClient.EXPECT().GetObject(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("session token does not exist"))
+		mockClient.EXPECT().PutObject(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil)
+
+		return mockClient, nil
+	}
+
+	key, err := keys.NewPrivateKey()
+	require.NoError(t, err)
+
+	pb := new(Builder)
+	pb.AddNode("peer0", 1)
+
+	opts := &BuilderOptions{
+		Key:           &key.PrivateKey,
+		clientBuilder: clientBuilder,
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	pool, err := pb.Build(ctx, opts)
+	require.NoError(t, err)
+
+	// cache must contain session token
+	_, st, err := pool.Connection()
+	require.NoError(t, err)
+	require.Contains(t, tokens, st)
+
+	_, err = pool.GetObjectParam(ctx, nil, &CallParam{})
+	require.Error(t, err)
+
+	// cache must not contain session token
+	_, st, err = pool.Connection()
+	require.NoError(t, err)
+	require.Nil(t, st)
+
+	_, err = pool.PutObjectParam(ctx, nil, &CallParam{})
+	require.NoError(t, err)
+
+	// cache must contain session token
+	_, st, err = pool.Connection()
+	require.NoError(t, err)
+	require.Contains(t, tokens, st)
+}
+
+func TestSessionCacheWithKey(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	var tokens []*session.Token
+	clientBuilder := func(opts ...client.Option) (client.Client, error) {
+		mockClient := NewMockClient(ctrl)
+		mockClient.EXPECT().CreateSession(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}, _ ...interface{}) (*session.Token, error) {
+			tok := session.NewToken()
+			uid, err := uuid.New().MarshalBinary()
+			require.NoError(t, err)
+			tok.SetID(uid)
+			tokens = append(tokens, tok)
+			return tok, err
+		}).MaxTimes(2)
+
+		mockClient.EXPECT().GetObject(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil)
+
+		return mockClient, nil
+	}
+
+	key, err := keys.NewPrivateKey()
+	require.NoError(t, err)
+	key2, err := keys.NewPrivateKey()
+	require.NoError(t, err)
+
+	pb := new(Builder)
+	pb.AddNode("peer0", 1)
+
+	opts := &BuilderOptions{
+		Key:           &key.PrivateKey,
+		clientBuilder: clientBuilder,
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	pool, err := pb.Build(ctx, opts)
+	require.NoError(t, err)
+
+	// cache must contain session token
+	_, st, err := pool.Connection()
+	require.NoError(t, err)
+	require.Contains(t, tokens, st)
+
+	_, err = pool.GetObjectParam(ctx, nil, &CallParam{Key: &key2.PrivateKey})
+	require.NoError(t, err)
+	require.Len(t, tokens, 2)
 }
 
 func newToken(t *testing.T) *session.Token {