diff --git a/pool/pool.go b/pool/pool.go index 472fe3ec..ac10e4fe 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -259,8 +259,9 @@ type innerPool struct { } const ( - DefaultSessionTokenExpirationDuration = 100 // in blocks - DefaultSessionTokenThreshold = 5 * time.Second + defaultSessionTokenExpirationDuration = 100 // in blocks + + defaultSessionTokenThreshold = 5 * time.Second ) func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { @@ -270,11 +271,11 @@ func newPool(ctx context.Context, options *BuilderOptions) (Pool, error) { } if options.SessionExpirationDuration == 0 { - options.SessionExpirationDuration = DefaultSessionTokenExpirationDuration + options.SessionExpirationDuration = defaultSessionTokenExpirationDuration } if options.SessionTokenThreshold <= 0 { - options.SessionTokenThreshold = DefaultSessionTokenThreshold + options.SessionTokenThreshold = defaultSessionTokenThreshold } ownerID := owner.NewIDFromPublicKey(&options.Key.PublicKey) @@ -588,8 +589,8 @@ func (p *pool) removeSessionTokenAfterThreshold(cfg *callConfig) error { } type callContext struct { - // context for RPC - ctxBase context.Context + // base context for RPC + context.Context client Client @@ -601,53 +602,98 @@ type callContext struct { // flag to open default session if session token is missing sessionDefault bool - sessionToken *session.Token + sessionTarget func(session.Token) } -func (p *pool) prepareCallContext(ctx *callContext, cfg *callConfig) error { +func (p *pool) initCallContext(ctx *callContext, cfg *callConfig) error { cp, err := p.connection() if err != nil { return err } - ctx.endpoint = cp.address - ctx.client = cp.client - ctx.key = cfg.key if ctx.key == nil { + // use pool key if caller didn't specify its own ctx.key = p.key } - ctx.sessionDefault = cfg.useDefaultSession - ctx.sessionToken = cfg.stoken + ctx.endpoint = cp.address + ctx.client = cp.client - if ctx.sessionToken == nil && ctx.sessionDefault { - cacheKey := formCacheKey(ctx.endpoint, ctx.key) - - ctx.sessionToken = p.cache.Get(cacheKey) - if ctx.sessionToken == nil { - var cliPrm client.CreateSessionPrm - - cliPrm.SetExp(math.MaxUint32) - - cliRes, err := ctx.client.CreateSession(ctx.ctxBase, cliPrm) - if err != nil { - return fmt.Errorf("default session: %w", err) - } - - ctx.sessionToken = sessionTokenForOwner(owner.NewIDFromPublicKey(&ctx.key.PublicKey), cliRes) - - _ = p.cache.Put(cacheKey, ctx.sessionToken) - } + if ctx.sessionTarget == nil && cfg.stoken != nil { + ctx.sessionTarget(*cfg.stoken) } - if ctx.sessionToken != nil && ctx.sessionToken.Signature() == nil { - err = ctx.sessionToken.Sign(ctx.key) - } + // note that we don't override session provided by the caller + ctx.sessionDefault = cfg.stoken == nil && cfg.useDefaultSession return err } +type callContextWithRetry struct { + callContext + + noRetry bool +} + +func (p *pool) initCallContextWithRetry(ctx *callContextWithRetry, cfg *callConfig) error { + err := p.initCallContext(&ctx.callContext, cfg) + if err != nil { + return err + } + + // don't retry if session was specified by the caller + ctx.noRetry = cfg.stoken != nil + + return nil +} + +func (p *pool) openDefaultSession(ctx *callContext) error { + cacheKey := formCacheKey(ctx.endpoint, ctx.key) + + tok := p.cache.Get(cacheKey) + if tok != nil { + // use cached token + ctx.sessionTarget(*tok) + return nil + } + + // open new session + var cliPrm client.CreateSessionPrm + + cliPrm.SetExp(math.MaxUint32) + + cliRes, err := ctx.client.CreateSession(ctx, cliPrm) + if err != nil { + return fmt.Errorf("session API client: %w", err) + } + + tok = sessionTokenForOwner(owner.NewIDFromPublicKey(&ctx.key.PublicKey), cliRes) + + // sign the token + err = tok.Sign(ctx.key) + if err != nil { + return fmt.Errorf("sign token of the opened session: %w", err) + } + + // cache the opened session + p.cache.Put(cacheKey, tok) + + ctx.sessionTarget(*tok) + + return nil +} + +func (p *pool) handleAttemptError(ctx *callContextWithRetry, err error) bool { + isTokenErr := p.checkSessionTokenErr(err, ctx.endpoint) + // note that checkSessionTokenErr must be called + res := isTokenErr && !ctx.noRetry + + ctx.noRetry = true + + return res +} + func (p *pool) PutObject(ctx context.Context, hdr object.Object, payload io.Reader, opts ...CallOption) (*oid.ID, error) { cfg := cfgFromOpts(append(opts, useDefaultSession())...) @@ -668,11 +714,11 @@ func (p *pool) PutObject(ctx context.Context, hdr object.Object, payload io.Read var ctxCall callContext - ctxCall.ctxBase = ctx + ctxCall.Context = ctx - err = p.prepareCallContext(&ctxCall, cfg) + err = p.initCallContext(&ctxCall, cfg) if err != nil { - return nil, err + return nil, fmt.Errorf("init call context") } var prm client.PrmObjectPutInit @@ -682,12 +728,16 @@ func (p *pool) PutObject(ctx context.Context, hdr object.Object, payload io.Read return nil, fmt.Errorf("init writing on API client: %w", err) } - wObj.UseKey(*ctxCall.key) - - if ctxCall.sessionToken != nil { - wObj.WithinSession(*ctxCall.sessionToken) + if ctxCall.sessionDefault { + ctxCall.sessionTarget = wObj.WithinSession + err = p.openDefaultSession(&ctxCall) + if err != nil { + return nil, fmt.Errorf("open default session: %w", err) + } } + wObj.UseKey(*ctxCall.key) + if cfg.btoken != nil { wObj.WithBearerToken(*cfg.btoken) } @@ -786,20 +836,42 @@ type ResGetObject struct { Payload io.ReadCloser } +func (p *pool) callWithRetry(ctx *callContextWithRetry, f func() error) error { + var err error + + if ctx.sessionDefault { + err = p.openDefaultSession(&ctx.callContext) + if err != nil { + return fmt.Errorf("open default session: %w", err) + } + } + + err = f() + + if p.checkSessionTokenErr(err, ctx.endpoint) && !ctx.noRetry { + // don't retry anymore + ctx.noRetry = true + return p.callWithRetry(ctx, f) + } + + return err +} + func (p *pool) GetObject(ctx context.Context, addr address.Address, opts ...CallOption) (*ResGetObject, error) { cfg := cfgFromOpts(append(opts, useDefaultSession())...) - var ctxCall callContext + var prm client.PrmObjectGet - ctxCall.ctxBase = ctx + var cc callContextWithRetry - err := p.prepareCallContext(&ctxCall, cfg) + cc.Context = ctx + cc.sessionTarget = prm.WithinSession + + err := p.initCallContextWithRetry(&cc, cfg) if err != nil { return nil, err } - var prm client.PrmObjectGet - if cnr := addr.ContainerID(); cnr != nil { prm.FromContainer(*cnr) } @@ -808,33 +880,25 @@ func (p *pool) GetObject(ctx context.Context, addr address.Address, opts ...Call prm.ByID(*obj) } - if ctxCall.sessionToken != nil { - prm.WithinSession(*ctxCall.sessionToken) - } - - if cfg.btoken != nil { - prm.WithBearerToken(*cfg.btoken) - } - - rObj, err := ctxCall.client.ObjectGetInit(ctx, prm) - if err != nil { - return nil, err - } - - rObj.UseKey(*ctxCall.key) - var res ResGetObject - if !rObj.ReadHeader(&res.Header) { - _, err = rObj.Close() - if p.checkSessionTokenErr(err, ctxCall.endpoint) && !cfg.isRetry { - return p.GetObject(ctx, addr, append(opts, retry())...) + err = p.callWithRetry(&cc, func() error { + rObj, err := cc.client.ObjectGetInit(ctx, prm) + if err != nil { + return fmt.Errorf("init object reading on client: %w", err) } - return nil, fmt.Errorf("read header: %w", err) - } + rObj.UseKey(*cc.key) - res.Payload = (*objectReadCloser)(rObj) + if !rObj.ReadHeader(&res.Header) { + _, err = rObj.Close() + return fmt.Errorf("read header: %w", err) + } + + res.Payload = (*objectReadCloser)(rObj) + + return nil + }) return &res, nil }