diff --git a/pool/client.go b/pool/client.go new file mode 100644 index 0000000..f1abbf2 --- /dev/null +++ b/pool/client.go @@ -0,0 +1,1283 @@ +package pool + +import ( + "bytes" + "context" + "crypto/ecdsa" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/accounting" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/ape" + sdkClient "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client" + apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container" + cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/session" + "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "google.golang.org/grpc" +) + +// errPoolClientUnhealthy is an error to indicate that client in pool is unhealthy. +var errPoolClientUnhealthy = errors.New("pool client unhealthy") + +// clientStatusMonitor count error rate and other statistics for connection. +type clientStatusMonitor struct { + logger *zap.Logger + addr string + healthy *atomic.Uint32 + errorThreshold uint32 + + mu sync.RWMutex // protect counters + currentErrorCount uint32 + overallErrorCount uint64 + methods []*MethodStatus +} + +// values for healthy status of clientStatusMonitor. +const ( + // statusUnhealthyOnRequest is set when communication after dialing to the + // endpoint is failed due to immediate or accumulated errors, connection is + // available and pool should close it before re-establishing connection once again. + statusUnhealthyOnRequest = iota + + // statusHealthy is set when connection is ready to be used by the pool. + statusHealthy +) + +// MethodIndex index of method in list of statuses in clientStatusMonitor. +type MethodIndex int + +const ( + methodBalanceGet MethodIndex = iota + methodContainerPut + methodContainerGet + methodContainerList + methodContainerListStream + methodContainerDelete + methodEndpointInfo + methodNetworkInfo + methodNetMapSnapshot + methodObjectPut + methodObjectDelete + methodObjectGet + methodObjectHead + methodObjectRange + methodObjectPatch + methodSessionCreate + methodAPEManagerAddChain + methodAPEManagerRemoveChain + methodAPEManagerListChains + methodLast +) + +// String implements fmt.Stringer. 
+func (m MethodIndex) String() string { + switch m { + case methodBalanceGet: + return "balanceGet" + case methodContainerPut: + return "containerPut" + case methodContainerGet: + return "containerGet" + case methodContainerList: + return "containerList" + case methodContainerDelete: + return "containerDelete" + case methodEndpointInfo: + return "endpointInfo" + case methodNetworkInfo: + return "networkInfo" + case methodNetMapSnapshot: + return "netMapSnapshot" + case methodObjectPut: + return "objectPut" + case methodObjectPatch: + return "objectPatch" + case methodObjectDelete: + return "objectDelete" + case methodObjectGet: + return "objectGet" + case methodObjectHead: + return "objectHead" + case methodObjectRange: + return "objectRange" + case methodSessionCreate: + return "sessionCreate" + case methodAPEManagerAddChain: + return "apeManagerAddChain" + case methodAPEManagerRemoveChain: + return "apeManagerRemoveChain" + case methodAPEManagerListChains: + return "apeManagerListChains" + case methodLast: + return "it's a system name rather than a method" + default: + return "unknown" + } +} + +func newClientStatusMonitor(logger *zap.Logger, addr string, errorThreshold uint32) clientStatusMonitor { + methods := make([]*MethodStatus, methodLast) + for i := methodBalanceGet; i < methodLast; i++ { + methods[i] = &MethodStatus{name: i.String()} + } + + healthy := new(atomic.Uint32) + healthy.Store(statusHealthy) + + return clientStatusMonitor{ + logger: logger, + addr: addr, + healthy: healthy, + errorThreshold: errorThreshold, + methods: methods, + } +} + +// clientWrapper is used by default, alternative implementations are intended for testing purposes only. +type clientWrapper struct { + clientMutex sync.RWMutex + client *sdkClient.Client + dialed bool + prm wrapperPrm + + clientStatusMonitor +} + +// wrapperPrm is params to create clientWrapper. +type wrapperPrm struct { + logger *zap.Logger + address string + key ecdsa.PrivateKey + dialTimeout time.Duration + streamTimeout time.Duration + errorThreshold uint32 + responseInfoCallback func(sdkClient.ResponseMetaInfo) error + poolRequestInfoCallback func(RequestInfo) + dialOptions []grpc.DialOption + + gracefulCloseOnSwitchTimeout time.Duration +} + +// setAddress sets endpoint to connect in FrostFS network. +func (x *wrapperPrm) setAddress(address string) { + x.address = address +} + +// setKey sets sdkClient.Client private key to be used for the protocol communication by default. +func (x *wrapperPrm) setKey(key ecdsa.PrivateKey) { + x.key = key +} + +// setLogger sets sdkClient.Client logger. +func (x *wrapperPrm) setLogger(logger *zap.Logger) { + x.logger = logger +} + +// setDialTimeout sets the timeout for connection to be established. +func (x *wrapperPrm) setDialTimeout(timeout time.Duration) { + x.dialTimeout = timeout +} + +// setStreamTimeout sets the timeout for individual operations in streaming RPC. +func (x *wrapperPrm) setStreamTimeout(timeout time.Duration) { + x.streamTimeout = timeout +} + +// setErrorThreshold sets threshold after reaching which connection is considered unhealthy +// until Pool.startRebalance routing updates its status. +func (x *wrapperPrm) setErrorThreshold(threshold uint32) { + x.errorThreshold = threshold +} + +// setGracefulCloseOnSwitchTimeout specifies the timeout after which unhealthy client be closed during rebalancing +// if it will become healthy back. +// +// See also setErrorThreshold. 
+func (x *wrapperPrm) setGracefulCloseOnSwitchTimeout(timeout time.Duration) { + x.gracefulCloseOnSwitchTimeout = timeout +} + +// setPoolRequestCallback sets callback that will be invoked after every pool response. +func (x *wrapperPrm) setPoolRequestCallback(f func(RequestInfo)) { + x.poolRequestInfoCallback = f +} + +// setResponseInfoCallback sets callback that will be invoked after every response. +func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo) error) { + x.responseInfoCallback = f +} + +// setGRPCDialOptions sets the gRPC dial options for new gRPC client connection. +func (x *wrapperPrm) setGRPCDialOptions(opts []grpc.DialOption) { + x.dialOptions = opts +} + +// newWrapper creates a clientWrapper that implements the client interface. +func newWrapper(prm wrapperPrm) *clientWrapper { + var cl sdkClient.Client + prmInit := sdkClient.PrmInit{ + Key: prm.key, + ResponseInfoCallback: prm.responseInfoCallback, + } + + cl.Init(prmInit) + + res := &clientWrapper{ + client: &cl, + clientStatusMonitor: newClientStatusMonitor(prm.logger, prm.address, prm.errorThreshold), + prm: prm, + } + + return res +} + +// dial establishes a connection to the server from the FrostFS network. +// Returns an error describing failure reason. If failed, the client +// SHOULD NOT be used. +func (c *clientWrapper) dial(ctx context.Context) error { + cl, err := c.getClient() + if err != nil { + return err + } + + prmDial := sdkClient.PrmDial{ + Endpoint: c.prm.address, + DialTimeout: c.prm.dialTimeout, + StreamTimeout: c.prm.streamTimeout, + GRPCDialOptions: c.prm.dialOptions, + } + + err = cl.Dial(ctx, prmDial) + c.setDialed(err == nil) + if err != nil { + return err + } + + return nil +} + +// restart recreates and redial inner sdk client. +func (c *clientWrapper) restart(ctx context.Context) error { + var cl sdkClient.Client + prmInit := sdkClient.PrmInit{ + Key: c.prm.key, + ResponseInfoCallback: c.prm.responseInfoCallback, + } + + cl.Init(prmInit) + + prmDial := sdkClient.PrmDial{ + Endpoint: c.prm.address, + DialTimeout: c.prm.dialTimeout, + StreamTimeout: c.prm.streamTimeout, + GRPCDialOptions: c.prm.dialOptions, + } + + // if connection is dialed before, to avoid routine / connection leak, + // pool has to close it and then initialize once again. + if c.isDialed() { + c.scheduleGracefulClose() + } + + err := cl.Dial(ctx, prmDial) + c.setDialed(err == nil) + if err != nil { + return err + } + + c.clientMutex.Lock() + c.client = &cl + c.clientMutex.Unlock() + + return nil +} + +func (c *clientWrapper) isDialed() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.dialed +} + +func (c *clientWrapper) setDialed(dialed bool) { + c.mu.Lock() + c.dialed = dialed + c.mu.Unlock() +} + +func (c *clientWrapper) getClient() (*sdkClient.Client, error) { + c.clientMutex.RLock() + defer c.clientMutex.RUnlock() + if c.isHealthy() { + return c.client, nil + } + return nil, errPoolClientUnhealthy +} + +func (c *clientWrapper) getClientRaw() *sdkClient.Client { + c.clientMutex.RLock() + defer c.clientMutex.RUnlock() + return c.client +} + +// balanceGet invokes sdkClient.BalanceGet parse response status to error and return result as is. 
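To make the wiring above concrete, here is a minimal in-package sketch of building and dialing a wrapper; the endpoint, key, timeout, and threshold values are illustrative assumptions rather than defaults introduced by this change, and ctx/key are assumed to be provided by the caller.

```go
// Hypothetical setup using the wrapperPrm setters and newWrapper defined above.
var prm wrapperPrm
prm.setAddress("grpc://node1.frostfs.example:8080") // assumed endpoint
prm.setKey(key)                                     // caller-owned ecdsa.PrivateKey
prm.setLogger(zap.NewNop())
prm.setDialTimeout(5 * time.Second)
prm.setStreamTimeout(10 * time.Second)
prm.setErrorThreshold(100)
prm.setGracefulCloseOnSwitchTimeout(30 * time.Second)

cw := newWrapper(prm)
if err := cw.dial(ctx); err != nil {
	// per the dial contract above, the wrapper must not be used after a failed dial
	return err
}
```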
+func (c *clientWrapper) balanceGet(ctx context.Context, prm PrmBalanceGet) (accounting.Decimal, error) { + cl, err := c.getClient() + if err != nil { + return accounting.Decimal{}, err + } + + cliPrm := sdkClient.PrmBalanceGet{ + Account: prm.account, + } + + start := time.Now() + res, err := cl.BalanceGet(ctx, cliPrm) + c.incRequests(time.Since(start), methodBalanceGet) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return accounting.Decimal{}, fmt.Errorf("balance get on client: %w", err) + } + + return res.Amount(), nil +} + +// containerPut invokes sdkClient.ContainerPut parse response status to error and return result as is. +// It also waits for the container to appear on the network. +func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) (cid.ID, error) { + cl, err := c.getClient() + if err != nil { + return cid.ID{}, err + } + + start := time.Now() + res, err := cl.ContainerPut(ctx, prm.ClientParams) + c.incRequests(time.Since(start), methodContainerPut) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return cid.ID{}, fmt.Errorf("container put on client: %w", err) + } + + if prm.WaitParams == nil { + prm.WaitParams = defaultWaitParams() + } + if err = prm.WaitParams.CheckValidity(); err != nil { + return cid.ID{}, fmt.Errorf("invalid wait parameters: %w", err) + } + + idCnr := res.ID() + + getPrm := PrmContainerGet{ + ContainerID: idCnr, + Session: prm.ClientParams.Session, + } + + err = waitForContainerPresence(ctx, c, getPrm, prm.WaitParams) + if err = c.handleError(ctx, nil, err); err != nil { + return cid.ID{}, fmt.Errorf("wait container presence on client: %w", err) + } + + return idCnr, nil +} + +// containerGet invokes sdkClient.ContainerGet parse response status to error and return result as is. +func (c *clientWrapper) containerGet(ctx context.Context, prm PrmContainerGet) (container.Container, error) { + cl, err := c.getClient() + if err != nil { + return container.Container{}, err + } + + cliPrm := sdkClient.PrmContainerGet{ + ContainerID: &prm.ContainerID, + Session: prm.Session, + } + + start := time.Now() + res, err := cl.ContainerGet(ctx, cliPrm) + c.incRequests(time.Since(start), methodContainerGet) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return container.Container{}, fmt.Errorf("container get on client: %w", err) + } + + return res.Container(), nil +} + +// containerList invokes sdkClient.ContainerList parse response status to error and return result as is. +func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList) ([]cid.ID, error) { + cl, err := c.getClient() + if err != nil { + return nil, err + } + + cliPrm := sdkClient.PrmContainerList{ + OwnerID: prm.OwnerID, + Session: prm.Session, + } + + start := time.Now() + res, err := cl.ContainerList(ctx, cliPrm) + c.incRequests(time.Since(start), methodContainerList) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return nil, fmt.Errorf("container list on client: %w", err) + } + return res.Containers(), nil +} + +// PrmListStream groups parameters of ListContainersStream operation. +type PrmListStream struct { + OwnerID user.ID + + Session *session.Container +} + +// ResListStream is designed to read list of object identifiers from FrostFS system. 
+// +// Must be initialized using Pool.ListContainersStream, any other usage is unsafe. +type ResListStream struct { + r *sdkClient.ContainerListReader + handleError func(context.Context, apistatus.Status, error) error +} + +// Read reads another list of the container identifiers. +func (x *ResListStream) Read(buf []cid.ID) (int, error) { + n, ok := x.r.Read(buf) + if !ok { + res, err := x.r.Close() + if err == nil { + return n, io.EOF + } + + var status apistatus.Status + if res != nil { + status = res.Status() + } + err = x.handleError(nil, status, err) + + return n, err + } + + return n, nil +} + +// Iterate iterates over the list of found container identifiers. +// f can return true to stop iteration earlier. +// +// Returns an error if container can't be read. +func (x *ResListStream) Iterate(f func(cid.ID) bool) error { + return x.r.Iterate(f) +} + +// Close ends reading list of the matched containers and returns the result of the operation +// along with the final results. Must be called after using the ResListStream. +func (x *ResListStream) Close() { + _, _ = x.r.Close() +} + +// containerList invokes sdkClient.ContainerList parse response status to error and return result as is. +func (c *clientWrapper) containerListStream(ctx context.Context, prm PrmListStream) (ResListStream, error) { + cl, err := c.getClient() + if err != nil { + return ResListStream{}, err + } + + cliPrm := sdkClient.PrmContainerListStream{ + OwnerID: prm.OwnerID, + Session: prm.Session, + } + + res, err := cl.ContainerListInit(ctx, cliPrm) + if err = c.handleError(ctx, nil, err); err != nil { + return ResListStream{}, fmt.Errorf("init container listing on client: %w", err) + } + return ResListStream{r: res, handleError: c.handleError}, nil +} + +// containerDelete invokes sdkClient.ContainerDelete parse response status to error. +// It also waits for the container to be removed from the network. +func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDelete) error { + cl, err := c.getClient() + if err != nil { + return err + } + + cliPrm := sdkClient.PrmContainerDelete{ + ContainerID: &prm.ContainerID, + Session: prm.Session, + } + + start := time.Now() + res, err := cl.ContainerDelete(ctx, cliPrm) + c.incRequests(time.Since(start), methodContainerDelete) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return fmt.Errorf("container delete on client: %w", err) + } + + if prm.WaitParams == nil { + prm.WaitParams = defaultWaitParams() + } + if err := prm.WaitParams.CheckValidity(); err != nil { + return fmt.Errorf("invalid wait parameters: %w", err) + } + + getPrm := PrmContainerGet{ + ContainerID: prm.ContainerID, + Session: prm.Session, + } + + return waitForContainerRemoved(ctx, c, getPrm, prm.WaitParams) +} + +// apeManagerAddChain invokes sdkClient.APEManagerAddChain and parse response status to error. 
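Referring back to the container listing stream above, a minimal in-package sketch of draining a ResListStream; the owner value and buffer size are assumptions for the example.

```go
// owner is an assumed user.ID; c is a *clientWrapper.
stream, err := c.containerListStream(ctx, PrmListStream{OwnerID: owner})
if err != nil {
	return err
}
defer stream.Close() // documented contract: Close must be called after use

buf := make([]cid.ID, 32)
for {
	n, err := stream.Read(buf)
	for _, id := range buf[:n] {
		fmt.Println(id.EncodeToString())
	}
	if errors.Is(err, io.EOF) {
		break // clean end of the listing
	}
	if err != nil {
		return err
	}
}
```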
+func (c *clientWrapper) apeManagerAddChain(ctx context.Context, prm PrmAddAPEChain) error { + cl, err := c.getClient() + if err != nil { + return err + } + + cliPrm := sdkClient.PrmAPEManagerAddChain{ + ChainTarget: prm.Target, + Chain: prm.Chain, + } + + start := time.Now() + res, err := cl.APEManagerAddChain(ctx, cliPrm) + c.incRequests(time.Since(start), methodAPEManagerAddChain) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return fmt.Errorf("add chain error: %w", err) + } + + return nil +} + +// apeManagerRemoveChain invokes sdkClient.APEManagerRemoveChain and parse response status to error. +func (c *clientWrapper) apeManagerRemoveChain(ctx context.Context, prm PrmRemoveAPEChain) error { + cl, err := c.getClient() + if err != nil { + return err + } + + cliPrm := sdkClient.PrmAPEManagerRemoveChain{ + ChainTarget: prm.Target, + ChainID: prm.ChainID, + } + + start := time.Now() + res, err := cl.APEManagerRemoveChain(ctx, cliPrm) + c.incRequests(time.Since(start), methodAPEManagerRemoveChain) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return fmt.Errorf("remove chain error: %w", err) + } + + return nil +} + +// apeManagerListChains invokes sdkClient.APEManagerListChains. Returns chains and parsed response status to error. +func (c *clientWrapper) apeManagerListChains(ctx context.Context, prm PrmListAPEChains) ([]ape.Chain, error) { + cl, err := c.getClient() + if err != nil { + return nil, err + } + + cliPrm := sdkClient.PrmAPEManagerListChains{ + ChainTarget: prm.Target, + } + + start := time.Now() + res, err := cl.APEManagerListChains(ctx, cliPrm) + c.incRequests(time.Since(start), methodAPEManagerListChains) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return nil, fmt.Errorf("list chains error: %w", err) + } + + return res.Chains, nil +} + +// endpointInfo invokes sdkClient.EndpointInfo parse response status to error and return result as is. +func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (netmap.NodeInfo, error) { + cl, err := c.getClient() + if err != nil { + return netmap.NodeInfo{}, err + } + + return c.endpointInfoRaw(ctx, cl) +} + +func (c *clientWrapper) healthcheck(ctx context.Context) (netmap.NodeInfo, error) { + cl := c.getClientRaw() + return c.endpointInfoRaw(ctx, cl) +} + +func (c *clientWrapper) endpointInfoRaw(ctx context.Context, cl *sdkClient.Client) (netmap.NodeInfo, error) { + start := time.Now() + res, err := cl.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}) + c.incRequests(time.Since(start), methodEndpointInfo) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return netmap.NodeInfo{}, fmt.Errorf("endpoint info on client: %w", err) + } + + return res.NodeInfo(), nil +} + +// networkInfo invokes sdkClient.NetworkInfo parse response status to error and return result as is. 
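For the APE manager wrappers above, a short usage sketch; the target and chain values are assumptions prepared by the caller, and the Prm* parameter types come from the rest of the pool package.

```go
// target (ape.ChainTarget) and chain (ape.Chain) are assumed inputs.
if err := c.apeManagerAddChain(ctx, PrmAddAPEChain{Target: target, Chain: chain}); err != nil {
	return err
}

chains, err := c.apeManagerListChains(ctx, PrmListAPEChains{Target: target})
if err != nil {
	return err
}
fmt.Println("chains attached to target:", len(chains))
```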
+func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netmap.NetworkInfo, error) { + cl, err := c.getClient() + if err != nil { + return netmap.NetworkInfo{}, err + } + + start := time.Now() + res, err := cl.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{}) + c.incRequests(time.Since(start), methodNetworkInfo) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return netmap.NetworkInfo{}, fmt.Errorf("network info on client: %w", err) + } + + return res.Info(), nil +} + +// networkInfo invokes sdkClient.NetworkInfo parse response status to error and return result as is. +func (c *clientWrapper) netMapSnapshot(ctx context.Context, _ prmNetMapSnapshot) (netmap.NetMap, error) { + cl, err := c.getClient() + if err != nil { + return netmap.NetMap{}, err + } + + start := time.Now() + res, err := cl.NetMapSnapshot(ctx, sdkClient.PrmNetMapSnapshot{}) + c.incRequests(time.Since(start), methodNetMapSnapshot) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return netmap.NetMap{}, fmt.Errorf("network map snapshot on client: %w", err) + } + + return res.NetMap(), nil +} + +// objectPatch patches object in FrostFS. +func (c *clientWrapper) objectPatch(ctx context.Context, prm PrmObjectPatch) (ResPatchObject, error) { + cl, err := c.getClient() + if err != nil { + return ResPatchObject{}, err + } + + start := time.Now() + pObj, err := cl.ObjectPatchInit(ctx, sdkClient.PrmObjectPatch{ + Address: prm.addr, + Session: prm.stoken, + Key: prm.key, + BearerToken: prm.btoken, + MaxChunkLength: prm.maxPayloadPatchChunkLength, + }) + if err = c.handleError(ctx, nil, err); err != nil { + return ResPatchObject{}, fmt.Errorf("init patching on API client: %w", err) + } + c.incRequests(time.Since(start), methodObjectPatch) + + start = time.Now() + attrPatchSuccess := pObj.PatchAttributes(ctx, prm.newAttrs, prm.replaceAttrs) + c.incRequests(time.Since(start), methodObjectPatch) + + if attrPatchSuccess { + start = time.Now() + _ = pObj.PatchPayload(ctx, prm.rng, prm.payload) + c.incRequests(time.Since(start), methodObjectPatch) + } + + res, err := pObj.Close(ctx) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return ResPatchObject{}, fmt.Errorf("client failure: %w", err) + } + + return ResPatchObject{ObjectID: res.ObjectID()}, nil +} + +// objectPut writes object to FrostFS. 
+func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (ResPutObject, error) { + if prm.bufferMaxSize == 0 { + prm.bufferMaxSize = defaultBufferMaxSizeForPut + } + + if prm.clientCut { + return c.objectPutClientCut(ctx, prm) + } + + return c.objectPutServerCut(ctx, prm) +} + +func (c *clientWrapper) objectPutServerCut(ctx context.Context, prm PrmObjectPut) (ResPutObject, error) { + cl, err := c.getClient() + if err != nil { + return ResPutObject{}, err + } + + cliPrm := sdkClient.PrmObjectPutInit{ + CopiesNumber: prm.copiesNumber, + Session: prm.stoken, + Key: prm.key, + BearerToken: prm.btoken, + } + + start := time.Now() + wObj, err := cl.ObjectPutInit(ctx, cliPrm) + c.incRequests(time.Since(start), methodObjectPut) + if err = c.handleError(ctx, nil, err); err != nil { + return ResPutObject{}, fmt.Errorf("init writing on API client: %w", err) + } + + if wObj.WriteHeader(ctx, prm.hdr) { + sz := prm.hdr.PayloadSize() + + if data := prm.hdr.Payload(); len(data) > 0 { + if prm.payload != nil { + prm.payload = io.MultiReader(bytes.NewReader(data), prm.payload) + } else { + prm.payload = bytes.NewReader(data) + sz = uint64(len(data)) + } + } + + if prm.payload != nil { + if sz == 0 || sz > prm.bufferMaxSize { + sz = prm.bufferMaxSize + } + + buf := make([]byte, sz) + + var n int + + for { + n, err = prm.payload.Read(buf) + if n > 0 { + start = time.Now() + successWrite := wObj.WritePayloadChunk(ctx, buf[:n]) + c.incRequests(time.Since(start), methodObjectPut) + if !successWrite { + break + } + + continue + } + + if errors.Is(err, io.EOF) { + break + } + + return ResPutObject{}, fmt.Errorf("read payload: %w", c.handleError(ctx, nil, err)) + } + } + } + + res, err := wObj.Close(ctx) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { // here err already carries both status and client errors + return ResPutObject{}, fmt.Errorf("client failure: %w", err) + } + + return ResPutObject{ + ObjectID: res.StoredObjectID(), + Epoch: res.StoredEpoch(), + }, nil +} + +func (c *clientWrapper) objectPutClientCut(ctx context.Context, prm PrmObjectPut) (ResPutObject, error) { + putInitPrm := PrmObjectPutClientCutInit{ + PrmObjectPut: prm, + } + + start := time.Now() + wObj, err := c.objectPutInitTransformer(putInitPrm) + c.incRequests(time.Since(start), methodObjectPut) + if err = c.handleError(ctx, nil, err); err != nil { + return ResPutObject{}, fmt.Errorf("init writing on API client: %w", err) + } + + if wObj.WriteHeader(ctx, prm.hdr) { + sz := prm.hdr.PayloadSize() + + if data := prm.hdr.Payload(); len(data) > 0 { + if prm.payload != nil { + prm.payload = io.MultiReader(bytes.NewReader(data), prm.payload) + } else { + prm.payload = bytes.NewReader(data) + sz = uint64(len(data)) + } + } + + if prm.payload != nil { + if sz == 0 || sz > prm.bufferMaxSize { + sz = prm.bufferMaxSize + } + + buf := make([]byte, sz) + + var n int + + for { + n, err = prm.payload.Read(buf) + if n > 0 { + start = time.Now() + successWrite := wObj.WritePayloadChunk(ctx, buf[:n]) + c.incRequests(time.Since(start), methodObjectPut) + if !successWrite { + break + } + + continue + } + + if errors.Is(err, io.EOF) { + break + } + + return ResPutObject{}, fmt.Errorf("read payload: %w", c.handleError(ctx, nil, err)) + } + } + } + + res, err := wObj.Close(ctx) + var st apistatus.Status + if res != nil { + st = res.Status + } + if err = c.handleError(ctx, st, err); err != nil { // here err already carries both status and client errors + return 
ResPutObject{}, fmt.Errorf("client failure: %w", err) + } + + return ResPutObject{ + ObjectID: res.OID, + Epoch: res.Epoch, + }, nil +} + +// objectDelete invokes sdkClient.ObjectDelete parse response status to error. +func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) error { + cl, err := c.getClient() + if err != nil { + return err + } + + cnr := prm.addr.Container() + obj := prm.addr.Object() + + cliPrm := sdkClient.PrmObjectDelete{ + BearerToken: prm.btoken, + Session: prm.stoken, + ContainerID: &cnr, + ObjectID: &obj, + Key: prm.key, + } + + start := time.Now() + res, err := cl.ObjectDelete(ctx, cliPrm) + c.incRequests(time.Since(start), methodObjectDelete) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return fmt.Errorf("delete object on client: %w", err) + } + return nil +} + +// objectGet returns reader for object. +func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGetObject, error) { + cl, err := c.getClient() + if err != nil { + return ResGetObject{}, err + } + + prmCnr := prm.addr.Container() + prmObj := prm.addr.Object() + + cliPrm := sdkClient.PrmObjectGet{ + BearerToken: prm.btoken, + Session: prm.stoken, + ContainerID: &prmCnr, + ObjectID: &prmObj, + Key: prm.key, + } + + var res ResGetObject + + rObj, err := cl.ObjectGetInit(ctx, cliPrm) + if err = c.handleError(ctx, nil, err); err != nil { + return ResGetObject{}, fmt.Errorf("init object reading on client: %w", err) + } + + start := time.Now() + successReadHeader := rObj.ReadHeader(&res.Header) + c.incRequests(time.Since(start), methodObjectGet) + if !successReadHeader { + rObjRes, err := rObj.Close() + var st apistatus.Status + if rObjRes != nil { + st = rObjRes.Status() + } + err = c.handleError(ctx, st, err) + return res, fmt.Errorf("read header: %w", err) + } + + res.Payload = &objectReadCloser{ + reader: rObj, + elapsedTimeCallback: func(elapsed time.Duration) { + c.incRequests(elapsed, methodObjectGet) + }, + } + + return res, nil +} + +// objectHead invokes sdkClient.ObjectHead parse response status to error and return result as is. +func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (object.Object, error) { + cl, err := c.getClient() + if err != nil { + return object.Object{}, err + } + + prmCnr := prm.addr.Container() + prmObj := prm.addr.Object() + + cliPrm := sdkClient.PrmObjectHead{ + BearerToken: prm.btoken, + Session: prm.stoken, + Raw: prm.raw, + ContainerID: &prmCnr, + ObjectID: &prmObj, + Key: prm.key, + } + + var obj object.Object + + start := time.Now() + res, err := cl.ObjectHead(ctx, cliPrm) + c.incRequests(time.Since(start), methodObjectHead) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return obj, fmt.Errorf("read object header via client: %w", err) + } + if !res.ReadHeader(&obj) { + return obj, errors.New("missing object header in response") + } + + return obj, nil +} + +// objectRange returns object range reader. 
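A minimal sketch of consuming the objectGet result above; prm is an assumed, already-populated PrmObjectGet.

```go
res, err := c.objectGet(ctx, prm)
if err != nil {
	return err
}
defer res.Payload.Close()

// Header is filled before objectGet returns; Payload reads are timed and
// reported to the objectGet method statistics via the elapsed-time callback.
fmt.Println("payload size:", res.Header.PayloadSize())
data, err := io.ReadAll(res.Payload)
if err != nil {
	return fmt.Errorf("read payload: %w", err)
}
_ = data
```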
+func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (ResObjectRange, error) { + cl, err := c.getClient() + if err != nil { + return ResObjectRange{}, err + } + + prmCnr := prm.addr.Container() + prmObj := prm.addr.Object() + + cliPrm := sdkClient.PrmObjectRange{ + BearerToken: prm.btoken, + Session: prm.stoken, + ContainerID: &prmCnr, + ObjectID: &prmObj, + Offset: prm.off, + Length: prm.ln, + Key: prm.key, + } + + start := time.Now() + res, err := cl.ObjectRangeInit(ctx, cliPrm) + c.incRequests(time.Since(start), methodObjectRange) + if err = c.handleError(ctx, nil, err); err != nil { + return ResObjectRange{}, fmt.Errorf("init payload range reading on client: %w", err) + } + + return ResObjectRange{ + payload: res, + elapsedTimeCallback: func(elapsed time.Duration) { + c.incRequests(elapsed, methodObjectRange) + }, + }, nil +} + +// objectSearch invokes sdkClient.ObjectSearchInit parse response status to error and return result as is. +func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) (ResObjectSearch, error) { + cl, err := c.getClient() + if err != nil { + return ResObjectSearch{}, err + } + + cliPrm := sdkClient.PrmObjectSearch{ + ContainerID: &prm.cnrID, + Filters: prm.filters, + Session: prm.stoken, + BearerToken: prm.btoken, + Key: prm.key, + } + + res, err := cl.ObjectSearchInit(ctx, cliPrm) + if err = c.handleError(ctx, nil, err); err != nil { + return ResObjectSearch{}, fmt.Errorf("init object searching on client: %w", err) + } + + return ResObjectSearch{r: res, handleError: c.handleError}, nil +} + +// sessionCreate invokes sdkClient.SessionCreate parse response status to error and return result as is. +func (c *clientWrapper) sessionCreate(ctx context.Context, prm prmCreateSession) (resCreateSession, error) { + cl, err := c.getClient() + if err != nil { + return resCreateSession{}, err + } + + cliPrm := sdkClient.PrmSessionCreate{ + Expiration: prm.exp, + Key: &prm.key, + } + + start := time.Now() + res, err := cl.SessionCreate(ctx, cliPrm) + c.incRequests(time.Since(start), methodSessionCreate) + var st apistatus.Status + if res != nil { + st = res.Status() + } + if err = c.handleError(ctx, st, err); err != nil { + return resCreateSession{}, fmt.Errorf("session creation on client: %w", err) + } + + return resCreateSession{ + id: res.ID(), + sessionKey: res.PublicKey(), + }, nil +} + +func (c *clientStatusMonitor) isHealthy() bool { + return c.healthy.Load() == statusHealthy +} + +func (c *clientStatusMonitor) setHealthy() { + c.healthy.Store(statusHealthy) +} + +func (c *clientStatusMonitor) setUnhealthy() { + c.healthy.Store(statusUnhealthyOnRequest) +} + +func (c *clientStatusMonitor) address() string { + return c.addr +} + +func (c *clientStatusMonitor) incErrorRate() { + c.mu.Lock() + c.currentErrorCount++ + c.overallErrorCount++ + + thresholdReached := c.currentErrorCount >= c.errorThreshold + if thresholdReached { + c.setUnhealthy() + c.currentErrorCount = 0 + } + c.mu.Unlock() + + if thresholdReached { + c.log(zapcore.WarnLevel, "error threshold reached", + zap.String("address", c.addr), zap.Uint32("threshold", c.errorThreshold)) + } +} + +func (c *clientStatusMonitor) incErrorRateToUnhealthy(err error) { + c.mu.Lock() + c.currentErrorCount = 0 + c.overallErrorCount++ + c.setUnhealthy() + c.mu.Unlock() + + c.log(zapcore.WarnLevel, "explicitly mark node unhealthy", zap.String("address", c.addr), zap.Error(err)) +} + +func (c *clientStatusMonitor) log(level zapcore.Level, msg string, fields ...zap.Field) { + if 
c.logger == nil { + return + } + + c.logger.Log(level, msg, fields...) +} + +func (c *clientStatusMonitor) currentErrorRate() uint32 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.currentErrorCount +} + +func (c *clientStatusMonitor) overallErrorRate() uint64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.overallErrorCount +} + +func (c *clientStatusMonitor) methodsStatus() []StatusSnapshot { + result := make([]StatusSnapshot, len(c.methods)) + for i, val := range c.methods { + result[i] = val.Snapshot() + } + + return result +} + +func (c *clientWrapper) incRequests(elapsed time.Duration, method MethodIndex) { + methodStat := c.methods[method] + methodStat.IncRequests(elapsed) + if c.prm.poolRequestInfoCallback != nil { + c.prm.poolRequestInfoCallback(RequestInfo{ + Address: c.prm.address, + Method: method, + Elapsed: elapsed, + }) + } +} + +func (c *clientWrapper) close() error { + if !c.isDialed() { + return nil + } + if cl := c.getClientRaw(); cl != nil { + return cl.Close() + } + return nil +} + +func (c *clientWrapper) scheduleGracefulClose() { + cl := c.getClientRaw() + if cl == nil { + return + } + + time.AfterFunc(c.prm.gracefulCloseOnSwitchTimeout, func() { + if err := cl.Close(); err != nil { + c.log(zap.DebugLevel, "close unhealthy client during rebalance", zap.String("address", c.address()), zap.Error(err)) + } + }) +} + +func (c *clientStatusMonitor) handleError(ctx context.Context, st apistatus.Status, err error) error { + if stErr := apistatus.ErrFromStatus(st); stErr != nil { + switch stErr.(type) { + case *apistatus.ServerInternal, + *apistatus.WrongMagicNumber, + *apistatus.SignatureVerification: + c.incErrorRate() + case *apistatus.NodeUnderMaintenance: + c.incErrorRateToUnhealthy(stErr) + } + + if err == nil { + err = stErr + } + + return err + } + + if err != nil { + if needCountError(ctx, err) { + if sdkClient.IsErrNodeUnderMaintenance(err) { + c.incErrorRateToUnhealthy(err) + } else { + c.incErrorRate() + } + } + + return err + } + + return nil +} + +func needCountError(ctx context.Context, err error) bool { + // non-status logic error that could be returned + // from the SDK client; should not be considered + // as a connection error + var siErr *object.SplitInfoError + if errors.As(err, &siErr) { + return false + } + var eiErr *object.ECInfoError + if errors.As(err, &eiErr) { + return false + } + + if ctx != nil && errors.Is(ctx.Err(), context.Canceled) { + return false + } + + return true +} + +// clientBuilder is a type alias of client constructors which open connection +// to the given endpoint. +type clientBuilder = func(endpoint string) client + +// RequestInfo groups info about pool request. 
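The threshold mechanics above can be seen in a small self-contained sketch (address and threshold values are arbitrary):

```go
m := newClientStatusMonitor(zap.NewNop(), "node1.frostfs.example:8080", 3)
m.incErrorRate()
m.incErrorRate()
m.incErrorRate() // third countable error reaches the threshold

fmt.Println(m.isHealthy())        // false: monitor switched to statusUnhealthyOnRequest
fmt.Println(m.currentErrorRate()) // 0: the window counter is reset on threshold
fmt.Println(m.overallErrorRate()) // 3: the cumulative counter keeps growing
```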
+type RequestInfo struct { + Address string + Method MethodIndex + Elapsed time.Duration +} diff --git a/pool/connection_manager.go b/pool/connection_manager.go new file mode 100644 index 0000000..b142529 --- /dev/null +++ b/pool/connection_manager.go @@ -0,0 +1,330 @@ +package pool + +import ( + "context" + "errors" + "fmt" + "math/rand" + "sort" + "sync" + "sync/atomic" + "time" + + apistatus "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/client/status" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type innerPool struct { + lock sync.RWMutex + sampler *sampler + clients []client +} + +type connectionManager struct { + innerPools []*innerPool + rebalanceParams rebalanceParameters + clientBuilder clientBuilder + logger *zap.Logger + healthChecker *healthCheck +} + +// newConnectionManager returns an instance of connectionManager configured according to the parameters. +// +// Before using connectionManager, you MUST call Dial. +func newConnectionManager(options InitParameters) (*connectionManager, error) { + if options.key == nil { + return nil, fmt.Errorf("missed required parameter 'Key'") + } + + nodesParams, err := adjustNodeParams(options.nodeParams) + if err != nil { + return nil, err + } + + manager := &connectionManager{ + logger: options.logger, + rebalanceParams: rebalanceParameters{ + nodesParams: nodesParams, + nodeRequestTimeout: options.healthcheckTimeout, + clientRebalanceInterval: options.clientRebalanceInterval, + sessionExpirationDuration: options.sessionExpirationDuration, + }, + clientBuilder: options.clientBuilder, + } + + return manager, nil +} + +func (cm *connectionManager) dial(ctx context.Context) error { + inner := make([]*innerPool, len(cm.rebalanceParams.nodesParams)) + var atLeastOneHealthy bool + + for i, params := range cm.rebalanceParams.nodesParams { + clients := make([]client, len(params.weights)) + for j, addr := range params.addresses { + clients[j] = cm.clientBuilder(addr) + if err := clients[j].dial(ctx); err != nil { + cm.log(zap.WarnLevel, "failed to build client", zap.String("address", addr), zap.Error(err)) + continue + } + atLeastOneHealthy = true + } + source := rand.NewSource(time.Now().UnixNano()) + sampl := newSampler(params.weights, source) + + inner[i] = &innerPool{ + sampler: sampl, + clients: clients, + } + } + + if !atLeastOneHealthy { + return fmt.Errorf("at least one node must be healthy") + } + + cm.innerPools = inner + + cm.healthChecker = newHealthCheck(cm.rebalanceParams.clientRebalanceInterval) + cm.healthChecker.startRebalance(ctx, cm.rebalance) + return nil +} + +func (cm *connectionManager) rebalance(ctx context.Context) { + buffers := make([][]float64, len(cm.rebalanceParams.nodesParams)) + for i, params := range cm.rebalanceParams.nodesParams { + buffers[i] = make([]float64, len(params.weights)) + } + + cm.updateNodesHealth(ctx, buffers) +} + +func (cm *connectionManager) log(level zapcore.Level, msg string, fields ...zap.Field) { + if cm.logger == nil { + return + } + + cm.logger.Log(level, msg, fields...) 
+} + +func adjustNodeParams(nodeParams []NodeParam) ([]*nodesParam, error) { + if len(nodeParams) == 0 { + return nil, errors.New("no FrostFS peers configured") + } + + nodesParamsMap := make(map[int]*nodesParam) + for _, param := range nodeParams { + nodes, ok := nodesParamsMap[param.priority] + if !ok { + nodes = &nodesParam{priority: param.priority} + } + nodes.addresses = append(nodes.addresses, param.address) + nodes.weights = append(nodes.weights, param.weight) + nodesParamsMap[param.priority] = nodes + } + + nodesParams := make([]*nodesParam, 0, len(nodesParamsMap)) + for _, nodes := range nodesParamsMap { + nodes.weights = adjustWeights(nodes.weights) + nodesParams = append(nodesParams, nodes) + } + + sort.Slice(nodesParams, func(i, j int) bool { + return nodesParams[i].priority < nodesParams[j].priority + }) + + return nodesParams, nil +} + +func (cm *connectionManager) updateNodesHealth(ctx context.Context, buffers [][]float64) { + wg := sync.WaitGroup{} + for i, inner := range cm.innerPools { + wg.Add(1) + + bufferWeights := buffers[i] + go func(i int, _ *innerPool) { + defer wg.Done() + cm.updateInnerNodesHealth(ctx, i, bufferWeights) + }(i, inner) + } + wg.Wait() +} + +func (cm *connectionManager) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights []float64) { + if i > len(cm.innerPools)-1 { + return + } + pool := cm.innerPools[i] + options := cm.rebalanceParams + + healthyChanged := new(atomic.Bool) + wg := sync.WaitGroup{} + + for j, cli := range pool.clients { + wg.Add(1) + go func(j int, cli client) { + defer wg.Done() + + tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout) + defer c() + + changed, err := restartIfUnhealthy(tctx, cli) + healthy := err == nil + if healthy { + bufferWeights[j] = options.nodesParams[i].weights[j] + } else { + bufferWeights[j] = 0 + } + + if changed { + fields := []zap.Field{zap.String("address", cli.address()), zap.Bool("healthy", healthy)} + if err != nil { + fields = append(fields, zap.String("reason", err.Error())) + } + + cm.log(zap.DebugLevel, "health has changed", fields...) + healthyChanged.Store(true) + } + }(j, cli) + } + wg.Wait() + + if healthyChanged.Load() { + probabilities := adjustWeights(bufferWeights) + source := rand.NewSource(time.Now().UnixNano()) + pool.lock.Lock() + pool.sampler = newSampler(probabilities, source) + pool.lock.Unlock() + } +} + +// restartIfUnhealthy checks healthy status of client and recreate it if status is unhealthy. +// Indicating if status was changed by this function call and returns error that caused unhealthy status. 
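To illustrate the grouping above, a sketch with assumed node parameters (the NodeParam literals use the package's unexported fields, so this only works in-package):

```go
params := []NodeParam{
	{priority: 1, address: "grpc://node1.example:8080", weight: 1},
	{priority: 1, address: "grpc://node2.example:8080", weight: 3},
	{priority: 2, address: "grpc://backup.example:8080", weight: 1},
}
groups, err := adjustNodeParams(params)
if err != nil {
	return err
}
// groups[0]: priority 1, weights normalized to [0.25 0.75]
// groups[1]: priority 2, weights normalized to [1]
// connection() walks the groups in priority order, so the backup node is
// only used when no priority-1 client is healthy.
_ = groups
```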
+func restartIfUnhealthy(ctx context.Context, c client) (changed bool, err error) { + defer func() { + if err != nil { + c.setUnhealthy() + } else { + c.setHealthy() + } + }() + + wasHealthy := c.isHealthy() + + if res, err := c.healthcheck(ctx); err == nil { + if res.Status().IsMaintenance() { + return wasHealthy, new(apistatus.NodeUnderMaintenance) + } + + return !wasHealthy, nil + } + + if err = c.restart(ctx); err != nil { + return wasHealthy, err + } + + res, err := c.healthcheck(ctx) + if err != nil { + return wasHealthy, err + } + + if res.Status().IsMaintenance() { + return wasHealthy, new(apistatus.NodeUnderMaintenance) + } + + return !wasHealthy, nil +} + +func adjustWeights(weights []float64) []float64 { + adjusted := make([]float64, len(weights)) + sum := 0.0 + for _, weight := range weights { + sum += weight + } + if sum > 0 { + for i, weight := range weights { + adjusted[i] = weight / sum + } + } + + return adjusted +} + +func (cm *connectionManager) connection() (client, error) { + for _, inner := range cm.innerPools { + cp, err := inner.connection() + if err == nil { + return cp, nil + } + } + + return nil, errors.New("no healthy client") +} + +// iterate iterates over all clients in all innerPools. +func (cm *connectionManager) iterate(cb func(client)) { + for _, inner := range cm.innerPools { + for _, cl := range inner.clients { + if cl.isHealthy() { + cb(cl) + } + } + } +} + +func (p *innerPool) connection() (client, error) { + p.lock.RLock() // need lock because of using p.sampler + defer p.lock.RUnlock() + if len(p.clients) == 1 { + cp := p.clients[0] + if cp.isHealthy() { + return cp, nil + } + return nil, errors.New("no healthy client") + } + attempts := 3 * len(p.clients) + for range attempts { + i := p.sampler.Next() + if cp := p.clients[i]; cp.isHealthy() { + return cp, nil + } + } + + return nil, errors.New("no healthy client") +} + +func (cm connectionManager) Statistic() Statistic { + stat := Statistic{} + for _, inner := range cm.innerPools { + nodes := make([]string, 0, len(inner.clients)) + for _, cl := range inner.clients { + if cl.isHealthy() { + nodes = append(nodes, cl.address()) + } + node := NodeStatistic{ + address: cl.address(), + methods: cl.methodsStatus(), + overallErrors: cl.overallErrorRate(), + currentErrors: cl.currentErrorRate(), + } + stat.nodes = append(stat.nodes, node) + stat.overallErrors += node.overallErrors + } + if len(stat.currentNodes) == 0 { + stat.currentNodes = nodes + } + } + + return stat +} + +func (cm *connectionManager) close() { + cm.healthChecker.stopRebalance() + + // close all clients + for _, pools := range cm.innerPools { + for _, cli := range pools.clients { + _ = cli.close() + } + } +} diff --git a/pool/healthcheck.go b/pool/healthcheck.go new file mode 100644 index 0000000..2f5dec9 --- /dev/null +++ b/pool/healthcheck.go @@ -0,0 +1,47 @@ +package pool + +import ( + "context" + "time" +) + +type healthCheck struct { + cancel context.CancelFunc + closedCh chan struct{} + + clientRebalanceInterval time.Duration +} + +func newHealthCheck(clientRebalanceInterval time.Duration) *healthCheck { + var h healthCheck + h.clientRebalanceInterval = clientRebalanceInterval + h.closedCh = make(chan struct{}) + return &h +} + +// startRebalance runs loop to monitor connection healthy status. 
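A brief sketch of how the health checker is wired, mirroring connectionManager.dial above; the interval is an assumed value and cm is an assumed *connectionManager.

```go
hc := newHealthCheck(15 * time.Second)
hc.startRebalance(ctx, cm.rebalance) // periodically invokes the rebalance callback
defer hc.stopRebalance()             // cancels the loop and waits for it to exit
```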
+func (h *healthCheck) startRebalance(ctx context.Context, callback func(ctx context.Context)) { + ctx, cancel := context.WithCancel(ctx) + h.cancel = cancel + + go func() { + ticker := time.NewTicker(h.clientRebalanceInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + close(h.closedCh) + return + case <-ticker.C: + callback(ctx) + ticker.Reset(h.clientRebalanceInterval) + } + } + }() +} + +func (h *healthCheck) stopRebalance() { + h.cancel() + <-h.closedCh +} diff --git a/pool/pool.go b/pool/pool.go index 1f69577..2f30ae4 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -1,17 +1,12 @@ package pool import ( - "bytes" "context" "crypto/ecdsa" "errors" "fmt" "io" "math" - "math/rand" - "sort" - "sync" - "sync/atomic" "time" "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/accounting" @@ -31,7 +26,6 @@ import ( "github.com/google/uuid" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "go.uber.org/zap" - "go.uber.org/zap/zapcore" "google.golang.org/grpc" ) @@ -112,1262 +106,6 @@ type clientStatus interface { methodsStatus() []StatusSnapshot } -// errPoolClientUnhealthy is an error to indicate that client in pool is unhealthy. -var errPoolClientUnhealthy = errors.New("pool client unhealthy") - -// clientStatusMonitor count error rate and other statistics for connection. -type clientStatusMonitor struct { - logger *zap.Logger - addr string - healthy *atomic.Uint32 - errorThreshold uint32 - - mu sync.RWMutex // protect counters - currentErrorCount uint32 - overallErrorCount uint64 - methods []*MethodStatus -} - -// values for healthy status of clientStatusMonitor. -const ( - // statusUnhealthyOnRequest is set when communication after dialing to the - // endpoint is failed due to immediate or accumulated errors, connection is - // available and pool should close it before re-establishing connection once again. - statusUnhealthyOnRequest = iota - - // statusHealthy is set when connection is ready to be used by the pool. - statusHealthy -) - -// MethodIndex index of method in list of statuses in clientStatusMonitor. -type MethodIndex int - -const ( - methodBalanceGet MethodIndex = iota - methodContainerPut - methodContainerGet - methodContainerList - methodContainerListStream - methodContainerDelete - methodEndpointInfo - methodNetworkInfo - methodNetMapSnapshot - methodObjectPut - methodObjectDelete - methodObjectGet - methodObjectHead - methodObjectRange - methodObjectPatch - methodSessionCreate - methodAPEManagerAddChain - methodAPEManagerRemoveChain - methodAPEManagerListChains - methodLast -) - -// String implements fmt.Stringer. 
-func (m MethodIndex) String() string { - switch m { - case methodBalanceGet: - return "balanceGet" - case methodContainerPut: - return "containerPut" - case methodContainerGet: - return "containerGet" - case methodContainerList: - return "containerList" - case methodContainerDelete: - return "containerDelete" - case methodEndpointInfo: - return "endpointInfo" - case methodNetworkInfo: - return "networkInfo" - case methodNetMapSnapshot: - return "netMapSnapshot" - case methodObjectPut: - return "objectPut" - case methodObjectPatch: - return "objectPatch" - case methodObjectDelete: - return "objectDelete" - case methodObjectGet: - return "objectGet" - case methodObjectHead: - return "objectHead" - case methodObjectRange: - return "objectRange" - case methodSessionCreate: - return "sessionCreate" - case methodAPEManagerAddChain: - return "apeManagerAddChain" - case methodAPEManagerRemoveChain: - return "apeManagerRemoveChain" - case methodAPEManagerListChains: - return "apeManagerListChains" - case methodLast: - return "it's a system name rather than a method" - default: - return "unknown" - } -} - -func newClientStatusMonitor(logger *zap.Logger, addr string, errorThreshold uint32) clientStatusMonitor { - methods := make([]*MethodStatus, methodLast) - for i := methodBalanceGet; i < methodLast; i++ { - methods[i] = &MethodStatus{name: i.String()} - } - - healthy := new(atomic.Uint32) - healthy.Store(statusHealthy) - - return clientStatusMonitor{ - logger: logger, - addr: addr, - healthy: healthy, - errorThreshold: errorThreshold, - methods: methods, - } -} - -// clientWrapper is used by default, alternative implementations are intended for testing purposes only. -type clientWrapper struct { - clientMutex sync.RWMutex - client *sdkClient.Client - dialed bool - prm wrapperPrm - - clientStatusMonitor -} - -// wrapperPrm is params to create clientWrapper. -type wrapperPrm struct { - logger *zap.Logger - address string - key ecdsa.PrivateKey - dialTimeout time.Duration - streamTimeout time.Duration - errorThreshold uint32 - responseInfoCallback func(sdkClient.ResponseMetaInfo) error - poolRequestInfoCallback func(RequestInfo) - dialOptions []grpc.DialOption - - gracefulCloseOnSwitchTimeout time.Duration -} - -// setAddress sets endpoint to connect in FrostFS network. -func (x *wrapperPrm) setAddress(address string) { - x.address = address -} - -// setKey sets sdkClient.Client private key to be used for the protocol communication by default. -func (x *wrapperPrm) setKey(key ecdsa.PrivateKey) { - x.key = key -} - -// setLogger sets sdkClient.Client logger. -func (x *wrapperPrm) setLogger(logger *zap.Logger) { - x.logger = logger -} - -// setDialTimeout sets the timeout for connection to be established. -func (x *wrapperPrm) setDialTimeout(timeout time.Duration) { - x.dialTimeout = timeout -} - -// setStreamTimeout sets the timeout for individual operations in streaming RPC. -func (x *wrapperPrm) setStreamTimeout(timeout time.Duration) { - x.streamTimeout = timeout -} - -// setErrorThreshold sets threshold after reaching which connection is considered unhealthy -// until Pool.startRebalance routing updates its status. -func (x *wrapperPrm) setErrorThreshold(threshold uint32) { - x.errorThreshold = threshold -} - -// setGracefulCloseOnSwitchTimeout specifies the timeout after which unhealthy client be closed during rebalancing -// if it will become healthy back. -// -// See also setErrorThreshold. 
-func (x *wrapperPrm) setGracefulCloseOnSwitchTimeout(timeout time.Duration) { - x.gracefulCloseOnSwitchTimeout = timeout -} - -// setPoolRequestCallback sets callback that will be invoked after every pool response. -func (x *wrapperPrm) setPoolRequestCallback(f func(RequestInfo)) { - x.poolRequestInfoCallback = f -} - -// setResponseInfoCallback sets callback that will be invoked after every response. -func (x *wrapperPrm) setResponseInfoCallback(f func(sdkClient.ResponseMetaInfo) error) { - x.responseInfoCallback = f -} - -// setGRPCDialOptions sets the gRPC dial options for new gRPC client connection. -func (x *wrapperPrm) setGRPCDialOptions(opts []grpc.DialOption) { - x.dialOptions = opts -} - -// newWrapper creates a clientWrapper that implements the client interface. -func newWrapper(prm wrapperPrm) *clientWrapper { - var cl sdkClient.Client - prmInit := sdkClient.PrmInit{ - Key: prm.key, - ResponseInfoCallback: prm.responseInfoCallback, - } - - cl.Init(prmInit) - - res := &clientWrapper{ - client: &cl, - clientStatusMonitor: newClientStatusMonitor(prm.logger, prm.address, prm.errorThreshold), - prm: prm, - } - - return res -} - -// dial establishes a connection to the server from the FrostFS network. -// Returns an error describing failure reason. If failed, the client -// SHOULD NOT be used. -func (c *clientWrapper) dial(ctx context.Context) error { - cl, err := c.getClient() - if err != nil { - return err - } - - prmDial := sdkClient.PrmDial{ - Endpoint: c.prm.address, - DialTimeout: c.prm.dialTimeout, - StreamTimeout: c.prm.streamTimeout, - GRPCDialOptions: c.prm.dialOptions, - } - - err = cl.Dial(ctx, prmDial) - c.setDialed(err == nil) - if err != nil { - return err - } - - return nil -} - -// restart recreates and redial inner sdk client. -func (c *clientWrapper) restart(ctx context.Context) error { - var cl sdkClient.Client - prmInit := sdkClient.PrmInit{ - Key: c.prm.key, - ResponseInfoCallback: c.prm.responseInfoCallback, - } - - cl.Init(prmInit) - - prmDial := sdkClient.PrmDial{ - Endpoint: c.prm.address, - DialTimeout: c.prm.dialTimeout, - StreamTimeout: c.prm.streamTimeout, - GRPCDialOptions: c.prm.dialOptions, - } - - // if connection is dialed before, to avoid routine / connection leak, - // pool has to close it and then initialize once again. - if c.isDialed() { - c.scheduleGracefulClose() - } - - err := cl.Dial(ctx, prmDial) - c.setDialed(err == nil) - if err != nil { - return err - } - - c.clientMutex.Lock() - c.client = &cl - c.clientMutex.Unlock() - - return nil -} - -func (c *clientWrapper) isDialed() bool { - c.mu.RLock() - defer c.mu.RUnlock() - return c.dialed -} - -func (c *clientWrapper) setDialed(dialed bool) { - c.mu.Lock() - c.dialed = dialed - c.mu.Unlock() -} - -func (c *clientWrapper) getClient() (*sdkClient.Client, error) { - c.clientMutex.RLock() - defer c.clientMutex.RUnlock() - if c.isHealthy() { - return c.client, nil - } - return nil, errPoolClientUnhealthy -} - -func (c *clientWrapper) getClientRaw() *sdkClient.Client { - c.clientMutex.RLock() - defer c.clientMutex.RUnlock() - return c.client -} - -// balanceGet invokes sdkClient.BalanceGet parse response status to error and return result as is. 
-func (c *clientWrapper) balanceGet(ctx context.Context, prm PrmBalanceGet) (accounting.Decimal, error) { - cl, err := c.getClient() - if err != nil { - return accounting.Decimal{}, err - } - - cliPrm := sdkClient.PrmBalanceGet{ - Account: prm.account, - } - - start := time.Now() - res, err := cl.BalanceGet(ctx, cliPrm) - c.incRequests(time.Since(start), methodBalanceGet) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return accounting.Decimal{}, fmt.Errorf("balance get on client: %w", err) - } - - return res.Amount(), nil -} - -// containerPut invokes sdkClient.ContainerPut parse response status to error and return result as is. -// It also waits for the container to appear on the network. -func (c *clientWrapper) containerPut(ctx context.Context, prm PrmContainerPut) (cid.ID, error) { - cl, err := c.getClient() - if err != nil { - return cid.ID{}, err - } - - start := time.Now() - res, err := cl.ContainerPut(ctx, prm.ClientParams) - c.incRequests(time.Since(start), methodContainerPut) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return cid.ID{}, fmt.Errorf("container put on client: %w", err) - } - - if prm.WaitParams == nil { - prm.WaitParams = defaultWaitParams() - } - if err = prm.WaitParams.CheckValidity(); err != nil { - return cid.ID{}, fmt.Errorf("invalid wait parameters: %w", err) - } - - idCnr := res.ID() - - getPrm := PrmContainerGet{ - ContainerID: idCnr, - Session: prm.ClientParams.Session, - } - - err = waitForContainerPresence(ctx, c, getPrm, prm.WaitParams) - if err = c.handleError(ctx, nil, err); err != nil { - return cid.ID{}, fmt.Errorf("wait container presence on client: %w", err) - } - - return idCnr, nil -} - -// containerGet invokes sdkClient.ContainerGet parse response status to error and return result as is. -func (c *clientWrapper) containerGet(ctx context.Context, prm PrmContainerGet) (container.Container, error) { - cl, err := c.getClient() - if err != nil { - return container.Container{}, err - } - - cliPrm := sdkClient.PrmContainerGet{ - ContainerID: &prm.ContainerID, - Session: prm.Session, - } - - start := time.Now() - res, err := cl.ContainerGet(ctx, cliPrm) - c.incRequests(time.Since(start), methodContainerGet) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return container.Container{}, fmt.Errorf("container get on client: %w", err) - } - - return res.Container(), nil -} - -// containerList invokes sdkClient.ContainerList parse response status to error and return result as is. -func (c *clientWrapper) containerList(ctx context.Context, prm PrmContainerList) ([]cid.ID, error) { - cl, err := c.getClient() - if err != nil { - return nil, err - } - - cliPrm := sdkClient.PrmContainerList{ - OwnerID: prm.OwnerID, - Session: prm.Session, - } - - start := time.Now() - res, err := cl.ContainerList(ctx, cliPrm) - c.incRequests(time.Since(start), methodContainerList) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return nil, fmt.Errorf("container list on client: %w", err) - } - return res.Containers(), nil -} - -// PrmListStream groups parameters of ListContainersStream operation. -type PrmListStream struct { - OwnerID user.ID - - Session *session.Container -} - -// ResListStream is designed to read list of object identifiers from FrostFS system. 
-// -// Must be initialized using Pool.ListContainersStream, any other usage is unsafe. -type ResListStream struct { - r *sdkClient.ContainerListReader - handleError func(context.Context, apistatus.Status, error) error -} - -// Read reads another list of the container identifiers. -func (x *ResListStream) Read(buf []cid.ID) (int, error) { - n, ok := x.r.Read(buf) - if !ok { - res, err := x.r.Close() - if err == nil { - return n, io.EOF - } - - var status apistatus.Status - if res != nil { - status = res.Status() - } - err = x.handleError(nil, status, err) - - return n, err - } - - return n, nil -} - -// Iterate iterates over the list of found container identifiers. -// f can return true to stop iteration earlier. -// -// Returns an error if container can't be read. -func (x *ResListStream) Iterate(f func(cid.ID) bool) error { - return x.r.Iterate(f) -} - -// Close ends reading list of the matched containers and returns the result of the operation -// along with the final results. Must be called after using the ResListStream. -func (x *ResListStream) Close() { - _, _ = x.r.Close() -} - -// containerList invokes sdkClient.ContainerList parse response status to error and return result as is. -func (c *clientWrapper) containerListStream(ctx context.Context, prm PrmListStream) (ResListStream, error) { - cl, err := c.getClient() - if err != nil { - return ResListStream{}, err - } - - cliPrm := sdkClient.PrmContainerListStream{ - OwnerID: prm.OwnerID, - Session: prm.Session, - } - - res, err := cl.ContainerListInit(ctx, cliPrm) - if err = c.handleError(ctx, nil, err); err != nil { - return ResListStream{}, fmt.Errorf("init container listing on client: %w", err) - } - return ResListStream{r: res, handleError: c.handleError}, nil -} - -// containerDelete invokes sdkClient.ContainerDelete parse response status to error. -// It also waits for the container to be removed from the network. -func (c *clientWrapper) containerDelete(ctx context.Context, prm PrmContainerDelete) error { - cl, err := c.getClient() - if err != nil { - return err - } - - cliPrm := sdkClient.PrmContainerDelete{ - ContainerID: &prm.ContainerID, - Session: prm.Session, - } - - start := time.Now() - res, err := cl.ContainerDelete(ctx, cliPrm) - c.incRequests(time.Since(start), methodContainerDelete) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return fmt.Errorf("container delete on client: %w", err) - } - - if prm.WaitParams == nil { - prm.WaitParams = defaultWaitParams() - } - if err := prm.WaitParams.CheckValidity(); err != nil { - return fmt.Errorf("invalid wait parameters: %w", err) - } - - getPrm := PrmContainerGet{ - ContainerID: prm.ContainerID, - Session: prm.Session, - } - - return waitForContainerRemoved(ctx, c, getPrm, prm.WaitParams) -} - -// apeManagerAddChain invokes sdkClient.APEManagerAddChain and parse response status to error. 
-func (c *clientWrapper) apeManagerAddChain(ctx context.Context, prm PrmAddAPEChain) error { - cl, err := c.getClient() - if err != nil { - return err - } - - cliPrm := sdkClient.PrmAPEManagerAddChain{ - ChainTarget: prm.Target, - Chain: prm.Chain, - } - - start := time.Now() - res, err := cl.APEManagerAddChain(ctx, cliPrm) - c.incRequests(time.Since(start), methodAPEManagerAddChain) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return fmt.Errorf("add chain error: %w", err) - } - - return nil -} - -// apeManagerRemoveChain invokes sdkClient.APEManagerRemoveChain and parse response status to error. -func (c *clientWrapper) apeManagerRemoveChain(ctx context.Context, prm PrmRemoveAPEChain) error { - cl, err := c.getClient() - if err != nil { - return err - } - - cliPrm := sdkClient.PrmAPEManagerRemoveChain{ - ChainTarget: prm.Target, - ChainID: prm.ChainID, - } - - start := time.Now() - res, err := cl.APEManagerRemoveChain(ctx, cliPrm) - c.incRequests(time.Since(start), methodAPEManagerRemoveChain) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return fmt.Errorf("remove chain error: %w", err) - } - - return nil -} - -// apeManagerListChains invokes sdkClient.APEManagerListChains. Returns chains and parsed response status to error. -func (c *clientWrapper) apeManagerListChains(ctx context.Context, prm PrmListAPEChains) ([]ape.Chain, error) { - cl, err := c.getClient() - if err != nil { - return nil, err - } - - cliPrm := sdkClient.PrmAPEManagerListChains{ - ChainTarget: prm.Target, - } - - start := time.Now() - res, err := cl.APEManagerListChains(ctx, cliPrm) - c.incRequests(time.Since(start), methodAPEManagerListChains) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return nil, fmt.Errorf("list chains error: %w", err) - } - - return res.Chains, nil -} - -// endpointInfo invokes sdkClient.EndpointInfo parse response status to error and return result as is. -func (c *clientWrapper) endpointInfo(ctx context.Context, _ prmEndpointInfo) (netmap.NodeInfo, error) { - cl, err := c.getClient() - if err != nil { - return netmap.NodeInfo{}, err - } - - return c.endpointInfoRaw(ctx, cl) -} - -func (c *clientWrapper) healthcheck(ctx context.Context) (netmap.NodeInfo, error) { - cl := c.getClientRaw() - return c.endpointInfoRaw(ctx, cl) -} - -func (c *clientWrapper) endpointInfoRaw(ctx context.Context, cl *sdkClient.Client) (netmap.NodeInfo, error) { - start := time.Now() - res, err := cl.EndpointInfo(ctx, sdkClient.PrmEndpointInfo{}) - c.incRequests(time.Since(start), methodEndpointInfo) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return netmap.NodeInfo{}, fmt.Errorf("endpoint info on client: %w", err) - } - - return res.NodeInfo(), nil -} - -// networkInfo invokes sdkClient.NetworkInfo parse response status to error and return result as is. 
-func (c *clientWrapper) networkInfo(ctx context.Context, _ prmNetworkInfo) (netmap.NetworkInfo, error) { - cl, err := c.getClient() - if err != nil { - return netmap.NetworkInfo{}, err - } - - start := time.Now() - res, err := cl.NetworkInfo(ctx, sdkClient.PrmNetworkInfo{}) - c.incRequests(time.Since(start), methodNetworkInfo) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return netmap.NetworkInfo{}, fmt.Errorf("network info on client: %w", err) - } - - return res.Info(), nil -} - -// networkInfo invokes sdkClient.NetworkInfo parse response status to error and return result as is. -func (c *clientWrapper) netMapSnapshot(ctx context.Context, _ prmNetMapSnapshot) (netmap.NetMap, error) { - cl, err := c.getClient() - if err != nil { - return netmap.NetMap{}, err - } - - start := time.Now() - res, err := cl.NetMapSnapshot(ctx, sdkClient.PrmNetMapSnapshot{}) - c.incRequests(time.Since(start), methodNetMapSnapshot) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return netmap.NetMap{}, fmt.Errorf("network map snapshot on client: %w", err) - } - - return res.NetMap(), nil -} - -// objectPatch patches object in FrostFS. -func (c *clientWrapper) objectPatch(ctx context.Context, prm PrmObjectPatch) (ResPatchObject, error) { - cl, err := c.getClient() - if err != nil { - return ResPatchObject{}, err - } - - start := time.Now() - pObj, err := cl.ObjectPatchInit(ctx, sdkClient.PrmObjectPatch{ - Address: prm.addr, - Session: prm.stoken, - Key: prm.key, - BearerToken: prm.btoken, - MaxChunkLength: prm.maxPayloadPatchChunkLength, - }) - if err = c.handleError(ctx, nil, err); err != nil { - return ResPatchObject{}, fmt.Errorf("init patching on API client: %w", err) - } - c.incRequests(time.Since(start), methodObjectPatch) - - start = time.Now() - attrPatchSuccess := pObj.PatchAttributes(ctx, prm.newAttrs, prm.replaceAttrs) - c.incRequests(time.Since(start), methodObjectPatch) - - if attrPatchSuccess { - start = time.Now() - _ = pObj.PatchPayload(ctx, prm.rng, prm.payload) - c.incRequests(time.Since(start), methodObjectPatch) - } - - res, err := pObj.Close(ctx) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return ResPatchObject{}, fmt.Errorf("client failure: %w", err) - } - - return ResPatchObject{ObjectID: res.ObjectID()}, nil -} - -// objectPut writes object to FrostFS. 
-func (c *clientWrapper) objectPut(ctx context.Context, prm PrmObjectPut) (ResPutObject, error) { - if prm.bufferMaxSize == 0 { - prm.bufferMaxSize = defaultBufferMaxSizeForPut - } - - if prm.clientCut { - return c.objectPutClientCut(ctx, prm) - } - - return c.objectPutServerCut(ctx, prm) -} - -func (c *clientWrapper) objectPutServerCut(ctx context.Context, prm PrmObjectPut) (ResPutObject, error) { - cl, err := c.getClient() - if err != nil { - return ResPutObject{}, err - } - - cliPrm := sdkClient.PrmObjectPutInit{ - CopiesNumber: prm.copiesNumber, - Session: prm.stoken, - Key: prm.key, - BearerToken: prm.btoken, - } - - start := time.Now() - wObj, err := cl.ObjectPutInit(ctx, cliPrm) - c.incRequests(time.Since(start), methodObjectPut) - if err = c.handleError(ctx, nil, err); err != nil { - return ResPutObject{}, fmt.Errorf("init writing on API client: %w", err) - } - - if wObj.WriteHeader(ctx, prm.hdr) { - sz := prm.hdr.PayloadSize() - - if data := prm.hdr.Payload(); len(data) > 0 { - if prm.payload != nil { - prm.payload = io.MultiReader(bytes.NewReader(data), prm.payload) - } else { - prm.payload = bytes.NewReader(data) - sz = uint64(len(data)) - } - } - - if prm.payload != nil { - if sz == 0 || sz > prm.bufferMaxSize { - sz = prm.bufferMaxSize - } - - buf := make([]byte, sz) - - var n int - - for { - n, err = prm.payload.Read(buf) - if n > 0 { - start = time.Now() - successWrite := wObj.WritePayloadChunk(ctx, buf[:n]) - c.incRequests(time.Since(start), methodObjectPut) - if !successWrite { - break - } - - continue - } - - if errors.Is(err, io.EOF) { - break - } - - return ResPutObject{}, fmt.Errorf("read payload: %w", c.handleError(ctx, nil, err)) - } - } - } - - res, err := wObj.Close(ctx) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { // here err already carries both status and client errors - return ResPutObject{}, fmt.Errorf("client failure: %w", err) - } - - return ResPutObject{ - ObjectID: res.StoredObjectID(), - Epoch: res.StoredEpoch(), - }, nil -} - -func (c *clientWrapper) objectPutClientCut(ctx context.Context, prm PrmObjectPut) (ResPutObject, error) { - putInitPrm := PrmObjectPutClientCutInit{ - PrmObjectPut: prm, - } - - start := time.Now() - wObj, err := c.objectPutInitTransformer(putInitPrm) - c.incRequests(time.Since(start), methodObjectPut) - if err = c.handleError(ctx, nil, err); err != nil { - return ResPutObject{}, fmt.Errorf("init writing on API client: %w", err) - } - - if wObj.WriteHeader(ctx, prm.hdr) { - sz := prm.hdr.PayloadSize() - - if data := prm.hdr.Payload(); len(data) > 0 { - if prm.payload != nil { - prm.payload = io.MultiReader(bytes.NewReader(data), prm.payload) - } else { - prm.payload = bytes.NewReader(data) - sz = uint64(len(data)) - } - } - - if prm.payload != nil { - if sz == 0 || sz > prm.bufferMaxSize { - sz = prm.bufferMaxSize - } - - buf := make([]byte, sz) - - var n int - - for { - n, err = prm.payload.Read(buf) - if n > 0 { - start = time.Now() - successWrite := wObj.WritePayloadChunk(ctx, buf[:n]) - c.incRequests(time.Since(start), methodObjectPut) - if !successWrite { - break - } - - continue - } - - if errors.Is(err, io.EOF) { - break - } - - return ResPutObject{}, fmt.Errorf("read payload: %w", c.handleError(ctx, nil, err)) - } - } - } - - res, err := wObj.Close(ctx) - var st apistatus.Status - if res != nil { - st = res.Status - } - if err = c.handleError(ctx, st, err); err != nil { // here err already carries both status and client errors - return 
ResPutObject{}, fmt.Errorf("client failure: %w", err) - } - - return ResPutObject{ - ObjectID: res.OID, - Epoch: res.Epoch, - }, nil -} - -// objectDelete invokes sdkClient.ObjectDelete parse response status to error. -func (c *clientWrapper) objectDelete(ctx context.Context, prm PrmObjectDelete) error { - cl, err := c.getClient() - if err != nil { - return err - } - - cnr := prm.addr.Container() - obj := prm.addr.Object() - - cliPrm := sdkClient.PrmObjectDelete{ - BearerToken: prm.btoken, - Session: prm.stoken, - ContainerID: &cnr, - ObjectID: &obj, - Key: prm.key, - } - - start := time.Now() - res, err := cl.ObjectDelete(ctx, cliPrm) - c.incRequests(time.Since(start), methodObjectDelete) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return fmt.Errorf("delete object on client: %w", err) - } - return nil -} - -// objectGet returns reader for object. -func (c *clientWrapper) objectGet(ctx context.Context, prm PrmObjectGet) (ResGetObject, error) { - cl, err := c.getClient() - if err != nil { - return ResGetObject{}, err - } - - prmCnr := prm.addr.Container() - prmObj := prm.addr.Object() - - cliPrm := sdkClient.PrmObjectGet{ - BearerToken: prm.btoken, - Session: prm.stoken, - ContainerID: &prmCnr, - ObjectID: &prmObj, - Key: prm.key, - } - - var res ResGetObject - - rObj, err := cl.ObjectGetInit(ctx, cliPrm) - if err = c.handleError(ctx, nil, err); err != nil { - return ResGetObject{}, fmt.Errorf("init object reading on client: %w", err) - } - - start := time.Now() - successReadHeader := rObj.ReadHeader(&res.Header) - c.incRequests(time.Since(start), methodObjectGet) - if !successReadHeader { - rObjRes, err := rObj.Close() - var st apistatus.Status - if rObjRes != nil { - st = rObjRes.Status() - } - err = c.handleError(ctx, st, err) - return res, fmt.Errorf("read header: %w", err) - } - - res.Payload = &objectReadCloser{ - reader: rObj, - elapsedTimeCallback: func(elapsed time.Duration) { - c.incRequests(elapsed, methodObjectGet) - }, - } - - return res, nil -} - -// objectHead invokes sdkClient.ObjectHead parse response status to error and return result as is. -func (c *clientWrapper) objectHead(ctx context.Context, prm PrmObjectHead) (object.Object, error) { - cl, err := c.getClient() - if err != nil { - return object.Object{}, err - } - - prmCnr := prm.addr.Container() - prmObj := prm.addr.Object() - - cliPrm := sdkClient.PrmObjectHead{ - BearerToken: prm.btoken, - Session: prm.stoken, - Raw: prm.raw, - ContainerID: &prmCnr, - ObjectID: &prmObj, - Key: prm.key, - } - - var obj object.Object - - start := time.Now() - res, err := cl.ObjectHead(ctx, cliPrm) - c.incRequests(time.Since(start), methodObjectHead) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return obj, fmt.Errorf("read object header via client: %w", err) - } - if !res.ReadHeader(&obj) { - return obj, errors.New("missing object header in response") - } - - return obj, nil -} - -// objectRange returns object range reader. 
-func (c *clientWrapper) objectRange(ctx context.Context, prm PrmObjectRange) (ResObjectRange, error) { - cl, err := c.getClient() - if err != nil { - return ResObjectRange{}, err - } - - prmCnr := prm.addr.Container() - prmObj := prm.addr.Object() - - cliPrm := sdkClient.PrmObjectRange{ - BearerToken: prm.btoken, - Session: prm.stoken, - ContainerID: &prmCnr, - ObjectID: &prmObj, - Offset: prm.off, - Length: prm.ln, - Key: prm.key, - } - - start := time.Now() - res, err := cl.ObjectRangeInit(ctx, cliPrm) - c.incRequests(time.Since(start), methodObjectRange) - if err = c.handleError(ctx, nil, err); err != nil { - return ResObjectRange{}, fmt.Errorf("init payload range reading on client: %w", err) - } - - return ResObjectRange{ - payload: res, - elapsedTimeCallback: func(elapsed time.Duration) { - c.incRequests(elapsed, methodObjectRange) - }, - }, nil -} - -// objectSearch invokes sdkClient.ObjectSearchInit parse response status to error and return result as is. -func (c *clientWrapper) objectSearch(ctx context.Context, prm PrmObjectSearch) (ResObjectSearch, error) { - cl, err := c.getClient() - if err != nil { - return ResObjectSearch{}, err - } - - cliPrm := sdkClient.PrmObjectSearch{ - ContainerID: &prm.cnrID, - Filters: prm.filters, - Session: prm.stoken, - BearerToken: prm.btoken, - Key: prm.key, - } - - res, err := cl.ObjectSearchInit(ctx, cliPrm) - if err = c.handleError(ctx, nil, err); err != nil { - return ResObjectSearch{}, fmt.Errorf("init object searching on client: %w", err) - } - - return ResObjectSearch{r: res, handleError: c.handleError}, nil -} - -// sessionCreate invokes sdkClient.SessionCreate parse response status to error and return result as is. -func (c *clientWrapper) sessionCreate(ctx context.Context, prm prmCreateSession) (resCreateSession, error) { - cl, err := c.getClient() - if err != nil { - return resCreateSession{}, err - } - - cliPrm := sdkClient.PrmSessionCreate{ - Expiration: prm.exp, - Key: &prm.key, - } - - start := time.Now() - res, err := cl.SessionCreate(ctx, cliPrm) - c.incRequests(time.Since(start), methodSessionCreate) - var st apistatus.Status - if res != nil { - st = res.Status() - } - if err = c.handleError(ctx, st, err); err != nil { - return resCreateSession{}, fmt.Errorf("session creation on client: %w", err) - } - - return resCreateSession{ - id: res.ID(), - sessionKey: res.PublicKey(), - }, nil -} - -func (c *clientStatusMonitor) isHealthy() bool { - return c.healthy.Load() == statusHealthy -} - -func (c *clientStatusMonitor) setHealthy() { - c.healthy.Store(statusHealthy) -} - -func (c *clientStatusMonitor) setUnhealthy() { - c.healthy.Store(statusUnhealthyOnRequest) -} - -func (c *clientStatusMonitor) address() string { - return c.addr -} - -func (c *clientStatusMonitor) incErrorRate() { - c.mu.Lock() - c.currentErrorCount++ - c.overallErrorCount++ - - thresholdReached := c.currentErrorCount >= c.errorThreshold - if thresholdReached { - c.setUnhealthy() - c.currentErrorCount = 0 - } - c.mu.Unlock() - - if thresholdReached { - c.log(zapcore.WarnLevel, "error threshold reached", - zap.String("address", c.addr), zap.Uint32("threshold", c.errorThreshold)) - } -} - -func (c *clientStatusMonitor) incErrorRateToUnhealthy(err error) { - c.mu.Lock() - c.currentErrorCount = 0 - c.overallErrorCount++ - c.setUnhealthy() - c.mu.Unlock() - - c.log(zapcore.WarnLevel, "explicitly mark node unhealthy", zap.String("address", c.addr), zap.Error(err)) -} - -func (c *clientStatusMonitor) log(level zapcore.Level, msg string, fields ...zap.Field) { - if 
c.logger == nil { - return - } - - c.logger.Log(level, msg, fields...) -} - -func (c *clientStatusMonitor) currentErrorRate() uint32 { - c.mu.RLock() - defer c.mu.RUnlock() - return c.currentErrorCount -} - -func (c *clientStatusMonitor) overallErrorRate() uint64 { - c.mu.RLock() - defer c.mu.RUnlock() - return c.overallErrorCount -} - -func (c *clientStatusMonitor) methodsStatus() []StatusSnapshot { - result := make([]StatusSnapshot, len(c.methods)) - for i, val := range c.methods { - result[i] = val.Snapshot() - } - - return result -} - -func (c *clientWrapper) incRequests(elapsed time.Duration, method MethodIndex) { - methodStat := c.methods[method] - methodStat.IncRequests(elapsed) - if c.prm.poolRequestInfoCallback != nil { - c.prm.poolRequestInfoCallback(RequestInfo{ - Address: c.prm.address, - Method: method, - Elapsed: elapsed, - }) - } -} - -func (c *clientWrapper) close() error { - if !c.isDialed() { - return nil - } - if cl := c.getClientRaw(); cl != nil { - return cl.Close() - } - return nil -} - -func (c *clientWrapper) scheduleGracefulClose() { - cl := c.getClientRaw() - if cl == nil { - return - } - - time.AfterFunc(c.prm.gracefulCloseOnSwitchTimeout, func() { - if err := cl.Close(); err != nil { - c.log(zap.DebugLevel, "close unhealthy client during rebalance", zap.String("address", c.address()), zap.Error(err)) - } - }) -} - -func (c *clientStatusMonitor) handleError(ctx context.Context, st apistatus.Status, err error) error { - if stErr := apistatus.ErrFromStatus(st); stErr != nil { - switch stErr.(type) { - case *apistatus.ServerInternal, - *apistatus.WrongMagicNumber, - *apistatus.SignatureVerification: - c.incErrorRate() - case *apistatus.NodeUnderMaintenance: - c.incErrorRateToUnhealthy(stErr) - } - - if err == nil { - err = stErr - } - - return err - } - - if err != nil { - if needCountError(ctx, err) { - if sdkClient.IsErrNodeUnderMaintenance(err) { - c.incErrorRateToUnhealthy(err) - } else { - c.incErrorRate() - } - } - - return err - } - - return nil -} - -func needCountError(ctx context.Context, err error) bool { - // non-status logic error that could be returned - // from the SDK client; should not be considered - // as a connection error - var siErr *object.SplitInfoError - if errors.As(err, &siErr) { - return false - } - var eiErr *object.ECInfoError - if errors.As(err, &eiErr) { - return false - } - - if ctx != nil && errors.Is(ctx.Err(), context.Canceled) { - return false - } - - return true -} - -// clientBuilder is a type alias of client constructors which open connection -// to the given endpoint. -type clientBuilder = func(endpoint string) client - -// RequestInfo groups info about pool request. -type RequestInfo struct { - Address string - Method MethodIndex - Elapsed time.Duration -} - // InitParameters contains values used to initialize connection Pool. type InitParameters struct { key *ecdsa.PrivateKey @@ -2006,25 +744,15 @@ type resCreateSession struct { // // See pool package overview to get some examples. 
type Pool struct { - innerPools []*innerPool - key *ecdsa.PrivateKey - cancel context.CancelFunc - closedCh chan struct{} - cache *sessionCache - stokenDuration uint64 - rebalanceParams rebalanceParameters - clientBuilder clientBuilder - logger *zap.Logger + manager *connectionManager + logger *zap.Logger + key *ecdsa.PrivateKey + cache *sessionCache + stokenDuration uint64 maxObjectSize uint64 } -type innerPool struct { - lock sync.RWMutex - sampler *sampler - clients []client -} - const ( defaultSessionTokenExpirationDuration = 100 // in epochs defaultErrorThreshold = 100 @@ -2038,17 +766,10 @@ const ( defaultBufferMaxSizeForPut = 3 * 1024 * 1024 // 3 MB ) -// NewPool creates connection pool using parameters. +// NewPool returns an instance of Pool configured according to the parameters. +// +// Before using Pool, you MUST call Dial. func NewPool(options InitParameters) (*Pool, error) { - if options.key == nil { - return nil, fmt.Errorf("missed required parameter 'Key'") - } - - nodesParams, err := adjustNodeParams(options.nodeParams) - if err != nil { - return nil, err - } - cache, err := newCache(options.sessionExpirationDuration) if err != nil { return nil, fmt.Errorf("couldn't create cache: %w", err) @@ -2056,18 +777,17 @@ func NewPool(options InitParameters) (*Pool, error) { fillDefaultInitParams(&options, cache) + manager, err := newConnectionManager(options) + if err != nil { + return nil, err + } + pool := &Pool{ - key: options.key, cache: cache, + key: options.key, logger: options.logger, + manager: manager, stokenDuration: options.sessionExpirationDuration, - rebalanceParams: rebalanceParameters{ - nodesParams: nodesParams, - nodeRequestTimeout: options.healthcheckTimeout, - clientRebalanceInterval: options.clientRebalanceInterval, - sessionExpirationDuration: options.sessionExpirationDuration, - }, - clientBuilder: options.clientBuilder, } return pool, nil @@ -2082,66 +802,39 @@ func NewPool(options InitParameters) (*Pool, error) { // // See also InitParameters.SetClientRebalanceInterval. 
func (p *Pool) Dial(ctx context.Context) error { - inner := make([]*innerPool, len(p.rebalanceParams.nodesParams)) - var atLeastOneHealthy bool - - for i, params := range p.rebalanceParams.nodesParams { - clients := make([]client, len(params.weights)) - for j, addr := range params.addresses { - clients[j] = p.clientBuilder(addr) - if err := clients[j].dial(ctx); err != nil { - p.log(zap.WarnLevel, "failed to build client", zap.String("address", addr), zap.Error(err)) - continue - } - - var st session.Object - err := initSessionForDuration(ctx, &st, clients[j], p.rebalanceParams.sessionExpirationDuration, *p.key, false) - if err != nil { - clients[j].setUnhealthy() - p.log(zap.WarnLevel, "failed to create frostfs session token for client", - zap.String("address", addr), zap.Error(err)) - continue - } - - _ = p.cache.Put(formCacheKey(addr, p.key, false), st) - atLeastOneHealthy = true - } - source := rand.NewSource(time.Now().UnixNano()) - sampl := newSampler(params.weights, source) - - inner[i] = &innerPool{ - sampler: sampl, - clients: clients, - } + err := p.manager.dial(ctx) + if err != nil { + return err } + var atLeastOneHealthy bool + p.manager.iterate(func(cl client) { + var st session.Object + err := initSessionForDuration(ctx, &st, cl, p.manager.rebalanceParams.sessionExpirationDuration, *p.key, false) + if err != nil { + if p.logger != nil { + p.logger.Log(zap.WarnLevel, "failed to create frostfs session token for client", + zap.String("address", cl.address()), zap.Error(err)) + } + return + } + + _ = p.cache.Put(formCacheKey(cl.address(), p.key, false), st) + atLeastOneHealthy = true + }) + if !atLeastOneHealthy { return fmt.Errorf("at least one node must be healthy") } - ctx, cancel := context.WithCancel(ctx) - p.cancel = cancel - p.closedCh = make(chan struct{}) - p.innerPools = inner - ni, err := p.NetworkInfo(ctx) if err != nil { return fmt.Errorf("get network info for max object size: %w", err) } p.maxObjectSize = ni.MaxObjectSize() - - go p.startRebalance(ctx) return nil } -func (p *Pool) log(level zapcore.Level, msg string, fields ...zap.Field) { - if p.logger == nil { - return - } - - p.logger.Log(level, msg, fields...) -} - func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { if params.sessionExpirationDuration == 0 { params.sessionExpirationDuration = defaultSessionTokenExpirationDuration @@ -2196,204 +889,6 @@ func fillDefaultInitParams(params *InitParameters, cache *sessionCache) { } } -func adjustNodeParams(nodeParams []NodeParam) ([]*nodesParam, error) { - if len(nodeParams) == 0 { - return nil, errors.New("no FrostFS peers configured") - } - - nodesParamsMap := make(map[int]*nodesParam) - for _, param := range nodeParams { - nodes, ok := nodesParamsMap[param.priority] - if !ok { - nodes = &nodesParam{priority: param.priority} - } - nodes.addresses = append(nodes.addresses, param.address) - nodes.weights = append(nodes.weights, param.weight) - nodesParamsMap[param.priority] = nodes - } - - nodesParams := make([]*nodesParam, 0, len(nodesParamsMap)) - for _, nodes := range nodesParamsMap { - nodes.weights = adjustWeights(nodes.weights) - nodesParams = append(nodesParams, nodes) - } - - sort.Slice(nodesParams, func(i, j int) bool { - return nodesParams[i].priority < nodesParams[j].priority - }) - - return nodesParams, nil -} - -// startRebalance runs loop to monitor connection healthy status. 
-func (p *Pool) startRebalance(ctx context.Context) { - ticker := time.NewTicker(p.rebalanceParams.clientRebalanceInterval) - defer ticker.Stop() - - buffers := make([][]float64, len(p.rebalanceParams.nodesParams)) - for i, params := range p.rebalanceParams.nodesParams { - buffers[i] = make([]float64, len(params.weights)) - } - - for { - select { - case <-ctx.Done(): - close(p.closedCh) - return - case <-ticker.C: - p.updateNodesHealth(ctx, buffers) - ticker.Reset(p.rebalanceParams.clientRebalanceInterval) - } - } -} - -func (p *Pool) updateNodesHealth(ctx context.Context, buffers [][]float64) { - wg := sync.WaitGroup{} - for i, inner := range p.innerPools { - wg.Add(1) - - bufferWeights := buffers[i] - go func(i int, _ *innerPool) { - defer wg.Done() - p.updateInnerNodesHealth(ctx, i, bufferWeights) - }(i, inner) - } - wg.Wait() -} - -func (p *Pool) updateInnerNodesHealth(ctx context.Context, i int, bufferWeights []float64) { - if i > len(p.innerPools)-1 { - return - } - pool := p.innerPools[i] - options := p.rebalanceParams - - healthyChanged := new(atomic.Bool) - wg := sync.WaitGroup{} - - for j, cli := range pool.clients { - wg.Add(1) - go func(j int, cli client) { - defer wg.Done() - - tctx, c := context.WithTimeout(ctx, options.nodeRequestTimeout) - defer c() - - changed, err := restartIfUnhealthy(tctx, cli) - healthy := err == nil - if healthy { - bufferWeights[j] = options.nodesParams[i].weights[j] - } else { - bufferWeights[j] = 0 - p.cache.DeleteByPrefix(cli.address()) - } - - if changed { - fields := []zap.Field{zap.String("address", cli.address()), zap.Bool("healthy", healthy)} - if err != nil { - fields = append(fields, zap.String("reason", err.Error())) - } - - p.log(zap.DebugLevel, "health has changed", fields...) - healthyChanged.Store(true) - } - }(j, cli) - } - wg.Wait() - - if healthyChanged.Load() { - probabilities := adjustWeights(bufferWeights) - source := rand.NewSource(time.Now().UnixNano()) - pool.lock.Lock() - pool.sampler = newSampler(probabilities, source) - pool.lock.Unlock() - } -} - -// restartIfUnhealthy checks healthy status of client and recreate it if status is unhealthy. -// Indicating if status was changed by this function call and returns error that caused unhealthy status. 
-func restartIfUnhealthy(ctx context.Context, c client) (changed bool, err error) { - defer func() { - if err != nil { - c.setUnhealthy() - } else { - c.setHealthy() - } - }() - - wasHealthy := c.isHealthy() - - if res, err := c.healthcheck(ctx); err == nil { - if res.Status().IsMaintenance() { - return wasHealthy, new(apistatus.NodeUnderMaintenance) - } - - return !wasHealthy, nil - } - - if err = c.restart(ctx); err != nil { - return wasHealthy, err - } - - res, err := c.healthcheck(ctx) - if err != nil { - return wasHealthy, err - } - - if res.Status().IsMaintenance() { - return wasHealthy, new(apistatus.NodeUnderMaintenance) - } - - return !wasHealthy, nil -} - -func adjustWeights(weights []float64) []float64 { - adjusted := make([]float64, len(weights)) - sum := 0.0 - for _, weight := range weights { - sum += weight - } - if sum > 0 { - for i, weight := range weights { - adjusted[i] = weight / sum - } - } - - return adjusted -} - -func (p *Pool) connection() (client, error) { - for _, inner := range p.innerPools { - cp, err := inner.connection() - if err == nil { - return cp, nil - } - } - - return nil, errors.New("no healthy client") -} - -func (p *innerPool) connection() (client, error) { - p.lock.RLock() // need lock because of using p.sampler - defer p.lock.RUnlock() - if len(p.clients) == 1 { - cp := p.clients[0] - if cp.isHealthy() { - return cp, nil - } - return nil, errors.New("no healthy client") - } - attempts := 3 * len(p.clients) - for range attempts { - i := p.sampler.Next() - if cp := p.clients[i]; cp.isHealthy() { - return cp, nil - } - } - - return nil, errors.New("no healthy client") -} - func formCacheKey(address string, key *ecdsa.PrivateKey, clientCut bool) string { k := keys.PrivateKey{PrivateKey: *key} @@ -2484,32 +979,33 @@ type callContext struct { sessionClientCut bool } -func (p *Pool) initCallContext(ctx *callContext, cfg prmCommon, prmCtx prmContext) error { - cp, err := p.connection() +func (p *Pool) initCall(ctxCall *callContext, cfg prmCommon, prmCtx prmContext) error { + p.fillAppropriateKey(&cfg) + cp, err := p.manager.connection() if err != nil { return err } - ctx.key = cfg.key - if ctx.key == nil { + ctxCall.key = cfg.key + if ctxCall.key == nil { // use pool key if caller didn't specify its own - ctx.key = p.key + ctxCall.key = p.key } - ctx.endpoint = cp.address() - ctx.client = cp + ctxCall.endpoint = cp.address() + ctxCall.client = cp - if ctx.sessionTarget != nil && cfg.stoken != nil { - ctx.sessionTarget(*cfg.stoken) + if ctxCall.sessionTarget != nil && cfg.stoken != nil { + ctxCall.sessionTarget(*cfg.stoken) } // note that we don't override session provided by the caller - ctx.sessionDefault = cfg.stoken == nil && prmCtx.defaultSession - if ctx.sessionDefault { - ctx.sessionVerb = prmCtx.verb - ctx.sessionCnr = prmCtx.cnr - ctx.sessionObjSet = prmCtx.objSet - ctx.sessionObjs = prmCtx.objs + ctxCall.sessionDefault = cfg.stoken == nil && prmCtx.defaultSession + if ctxCall.sessionDefault { + ctxCall.sessionVerb = prmCtx.verb + ctxCall.sessionCnr = prmCtx.cnr + ctxCall.sessionObjSet = prmCtx.objSet + ctxCall.sessionObjs = prmCtx.objs } return err @@ -2586,18 +1082,14 @@ type ResPatchObject struct { } // PatchObject patches an object through a remote server using FrostFS API protocol. -// -// Main return value MUST NOT be processed on an erroneous return. 
func (p *Pool) PatchObject(ctx context.Context, prm PrmObjectPatch) (ResPatchObject, error) { var prmCtx prmContext prmCtx.useDefaultSession() prmCtx.useVerb(session.VerbObjectPatch) prmCtx.useContainer(prm.addr.Container()) - p.fillAppropriateKey(&prm.prmCommon) - var ctxCall callContext - if err := p.initCallContext(&ctxCall, prm.prmCommon, prmCtx); err != nil { + if err := p.initCall(&ctxCall, prm.prmCommon, prmCtx); err != nil { return ResPatchObject{}, fmt.Errorf("init call context: %w", err) } @@ -2619,8 +1111,6 @@ func (p *Pool) PatchObject(ctx context.Context, prm PrmObjectPatch) (ResPatchObj } // PutObject writes an object through a remote server using FrostFS API protocol. -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (ResPutObject, error) { cnr, _ := prm.hdr.ContainerID() @@ -2629,11 +1119,9 @@ func (p *Pool) PutObject(ctx context.Context, prm PrmObjectPut) (ResPutObject, e prmCtx.useVerb(session.VerbObjectPut) prmCtx.useContainer(cnr) - p.fillAppropriateKey(&prm.prmCommon) - var ctxCall callContext ctxCall.sessionClientCut = prm.clientCut - if err := p.initCallContext(&ctxCall, prm.prmCommon, prmCtx); err != nil { + if err := p.initCall(&ctxCall, prm.prmCommon, prmCtx); err != nil { return ResPutObject{}, fmt.Errorf("init call context: %w", err) } @@ -2686,12 +1174,10 @@ func (p *Pool) DeleteObject(ctx context.Context, prm PrmObjectDelete) error { } } - p.fillAppropriateKey(&prm.prmCommon) - var cc callContext cc.sessionTarget = prm.UseSession - err := p.initCallContext(&cc, prm.prmCommon, prmCtx) + err := p.initCall(&cc, prm.prmCommon, prmCtx) if err != nil { return err } @@ -2732,17 +1218,13 @@ type ResGetObject struct { } // GetObject reads object header and initiates reading an object payload through a remote server using FrostFS API protocol. -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) GetObject(ctx context.Context, prm PrmObjectGet) (ResGetObject, error) { - p.fillAppropriateKey(&prm.prmCommon) - var cc callContext cc.sessionTarget = prm.UseSession var res ResGetObject - err := p.initCallContext(&cc, prm.prmCommon, prmContext{}) + err := p.initCall(&cc, prm.prmCommon, prmContext{}) if err != nil { return res, err } @@ -2757,17 +1239,13 @@ func (p *Pool) GetObject(ctx context.Context, prm PrmObjectGet) (ResGetObject, e } // HeadObject reads object header through a remote server using FrostFS API protocol. -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) HeadObject(ctx context.Context, prm PrmObjectHead) (object.Object, error) { - p.fillAppropriateKey(&prm.prmCommon) - var cc callContext cc.sessionTarget = prm.UseSession var obj object.Object - err := p.initCallContext(&cc, prm.prmCommon, prmContext{}) + err := p.initCall(&cc, prm.prmCommon, prmContext{}) if err != nil { return obj, err } @@ -2808,17 +1286,13 @@ func (x *ResObjectRange) Close() error { // ObjectRange initiates reading an object's payload range through a remote // server using FrostFS API protocol. -// -// Main return value MUST NOT be processed on an erroneous return. 
func (p *Pool) ObjectRange(ctx context.Context, prm PrmObjectRange) (ResObjectRange, error) { - p.fillAppropriateKey(&prm.prmCommon) - var cc callContext cc.sessionTarget = prm.UseSession var res ResObjectRange - err := p.initCallContext(&cc, prm.prmCommon, prmContext{}) + err := p.initCall(&cc, prm.prmCommon, prmContext{}) if err != nil { return res, err } @@ -2879,17 +1353,13 @@ func (x *ResObjectSearch) Close() { // // The call only opens the transmission channel, explicit fetching of matched objects // is done using the ResObjectSearch. Resulting reader must be finally closed. -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) SearchObjects(ctx context.Context, prm PrmObjectSearch) (ResObjectSearch, error) { - p.fillAppropriateKey(&prm.prmCommon) - var cc callContext cc.sessionTarget = prm.UseSession var res ResObjectSearch - err := p.initCallContext(&cc, prm.prmCommon, prmContext{}) + err := p.initCall(&cc, prm.prmCommon, prmContext{}) if err != nil { return res, err } @@ -2911,10 +1381,8 @@ func (p *Pool) SearchObjects(ctx context.Context, prm PrmObjectSearch) (ResObjec // waiting timeout: 120s // // Success can be verified by reading by identifier (see GetContainer). -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (cid.ID, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return cid.ID{}, err } @@ -2928,10 +1396,8 @@ func (p *Pool) PutContainer(ctx context.Context, prm PrmContainerPut) (cid.ID, e } // GetContainer reads FrostFS container by ID. -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (container.Container, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return container.Container{}, err } @@ -2946,7 +1412,7 @@ func (p *Pool) GetContainer(ctx context.Context, prm PrmContainerGet) (container // ListContainers requests identifiers of the account-owned containers. func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid.ID, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return nil, err } @@ -2962,7 +1428,7 @@ func (p *Pool) ListContainers(ctx context.Context, prm PrmContainerList) ([]cid. // ListContainersStream requests identifiers of the account-owned containers. func (p *Pool) ListContainersStream(ctx context.Context, prm PrmListStream) (ResListStream, error) { var res ResListStream - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return res, err } @@ -2984,7 +1450,7 @@ func (p *Pool) ListContainersStream(ctx context.Context, prm PrmListStream) (Res // // Success can be verified by reading by identifier (see GetContainer). func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) error { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return err } @@ -2999,7 +1465,7 @@ func (p *Pool) DeleteContainer(ctx context.Context, prm PrmContainerDelete) erro // AddAPEChain sends a request to set APE chain rules for a target (basically, for a container). 
func (p *Pool) AddAPEChain(ctx context.Context, prm PrmAddAPEChain) error { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return err } @@ -3014,7 +1480,7 @@ func (p *Pool) AddAPEChain(ctx context.Context, prm PrmAddAPEChain) error { // RemoveAPEChain sends a request to remove APE chain rules for a target. func (p *Pool) RemoveAPEChain(ctx context.Context, prm PrmRemoveAPEChain) error { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return err } @@ -3029,7 +1495,7 @@ func (p *Pool) RemoveAPEChain(ctx context.Context, prm PrmRemoveAPEChain) error // ListAPEChains sends a request to list APE chains rules for a target. func (p *Pool) ListAPEChains(ctx context.Context, prm PrmListAPEChains) ([]ape.Chain, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return nil, err } @@ -3043,10 +1509,8 @@ func (p *Pool) ListAPEChains(ctx context.Context, prm PrmListAPEChains) ([]ape.C } // Balance requests current balance of the FrostFS account. -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (accounting.Decimal, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return accounting.Decimal{}, err } @@ -3061,30 +1525,7 @@ func (p *Pool) Balance(ctx context.Context, prm PrmBalanceGet) (accounting.Decim // Statistic returns connection statistics. func (p Pool) Statistic() Statistic { - stat := Statistic{} - for _, inner := range p.innerPools { - nodes := make([]string, 0, len(inner.clients)) - inner.lock.RLock() - for _, cl := range inner.clients { - if cl.isHealthy() { - nodes = append(nodes, cl.address()) - } - node := NodeStatistic{ - address: cl.address(), - methods: cl.methodsStatus(), - overallErrors: cl.overallErrorRate(), - currentErrors: cl.currentErrorRate(), - } - stat.nodes = append(stat.nodes, node) - stat.overallErrors += node.overallErrors - } - inner.lock.RUnlock() - if len(stat.currentNodes) == 0 { - stat.currentNodes = nodes - } - } - - return stat + return p.manager.Statistic() } // waitForContainerPresence waits until the container is found on the FrostFS network. @@ -3127,10 +1568,8 @@ func waitFor(ctx context.Context, params *WaitParams, condition func(context.Con } // NetworkInfo requests information about the FrostFS network of which the remote server is a part. -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) NetworkInfo(ctx context.Context) (netmap.NetworkInfo, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return netmap.NetworkInfo{}, err } @@ -3144,10 +1583,8 @@ func (p *Pool) NetworkInfo(ctx context.Context) (netmap.NetworkInfo, error) { } // NetMapSnapshot requests information about the FrostFS network map. -// -// Main return value MUST NOT be processed on an erroneous return. func (p *Pool) NetMapSnapshot(ctx context.Context) (netmap.NetMap, error) { - cp, err := p.connection() + cp, err := p.manager.connection() if err != nil { return netmap.NetMap{}, err } @@ -3162,15 +1599,7 @@ func (p *Pool) NetMapSnapshot(ctx context.Context) (netmap.NetMap, error) { // Close closes the Pool and releases all the associated resources. 
func (p *Pool) Close() { - p.cancel() - <-p.closedCh - - // close all clients - for _, pools := range p.innerPools { - for _, cli := range pools.clients { - _ = cli.close() - } - } + p.manager.close() } // SyncContainerWithNetwork applies network configuration received via diff --git a/pool/pool_test.go b/pool/pool_test.go index 1362654..b063294 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -104,7 +104,7 @@ func TestBuildPoolOneNodeFailed(t *testing.T) { expectedAuthKey := frostfsecdsa.PublicKey(clientKeys[1].PublicKey) condition := func() bool { - cp, err := clientPool.connection() + cp, err := clientPool.manager.connection() if err != nil { return false } @@ -141,7 +141,7 @@ func TestOneNode(t *testing.T) { require.NoError(t, err) t.Cleanup(pool.Close) - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) expectedAuthKey := frostfsecdsa.PublicKey(key1.PublicKey) @@ -171,7 +171,7 @@ func TestTwoNodes(t *testing.T) { require.NoError(t, err) t.Cleanup(pool.Close) - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, assertAuthKeyForAny(st, clientKeys)) @@ -220,13 +220,12 @@ func TestOneOfTwoFailed(t *testing.T) { err = pool.Dial(context.Background()) require.NoError(t, err) - require.NoError(t, err) t.Cleanup(pool.Close) time.Sleep(2 * time.Second) for range 5 { - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, assertAuthKeyForAny(st, clientKeys)) @@ -369,7 +368,7 @@ func TestUpdateNodesHealth(t *testing.T) { tc.prepareCli(cli) p, log := newPool(t, cli) - p.updateNodesHealth(ctx, [][]float64{{1}}) + p.manager.updateNodesHealth(ctx, [][]float64{{1}}) changed := tc.wasHealthy != tc.willHealthy require.Equalf(t, tc.willHealthy, cli.isHealthy(), "healthy status should be: %v", tc.willHealthy) @@ -385,19 +384,19 @@ func newPool(t *testing.T, cli *mockClient) (*Pool, *observer.ObservedLogs) { require.NoError(t, err) return &Pool{ - innerPools: []*innerPool{{ - sampler: newSampler([]float64{1}, rand.NewSource(0)), - clients: []client{cli}, - }}, - cache: cache, - key: newPrivateKey(t), - closedCh: make(chan struct{}), - rebalanceParams: rebalanceParameters{ - nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}}, - nodeRequestTimeout: time.Second, - clientRebalanceInterval: 200 * time.Millisecond, - }, - logger: log, + cache: cache, + key: newPrivateKey(t), + manager: &connectionManager{ + innerPools: []*innerPool{{ + sampler: newSampler([]float64{1}, rand.NewSource(0)), + clients: []client{cli}, + }}, + healthChecker: newHealthCheck(200 * time.Millisecond), + rebalanceParams: rebalanceParameters{ + nodesParams: []*nodesParam{{1, []string{"peer0"}, []float64{1}}}, + nodeRequestTimeout: time.Second, + }, + logger: log}, }, observedLog } @@ -435,7 +434,7 @@ func TestTwoFailed(t *testing.T) { time.Sleep(2 * time.Second) - _, err = pool.connection() + _, err = pool.manager.connection() require.Error(t, err) require.Contains(t, err.Error(), "no healthy") } @@ -469,7 +468,7 @@ func TestSessionCache(t *testing.T) { t.Cleanup(pool.Close) // cache must contain session token - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, 
false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) @@ -482,7 +481,7 @@ func TestSessionCache(t *testing.T) { require.Error(t, err) // cache must not contain session token - cp, err = pool.connection() + cp, err = pool.manager.connection() require.NoError(t, err) _, ok := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.False(t, ok) @@ -494,7 +493,7 @@ func TestSessionCache(t *testing.T) { require.NoError(t, err) // cache must contain session token - cp, err = pool.connection() + cp, err = pool.manager.connection() require.NoError(t, err) st, _ = pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) @@ -538,7 +537,7 @@ func TestPriority(t *testing.T) { expectedAuthKey1 := frostfsecdsa.PublicKey(clientKeys[0].PublicKey) firstNode := func() bool { - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) return st.AssertAuthKey(&expectedAuthKey1) @@ -546,7 +545,7 @@ func TestPriority(t *testing.T) { expectedAuthKey2 := frostfsecdsa.PublicKey(clientKeys[1].PublicKey) secondNode := func() bool { - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) return st.AssertAuthKey(&expectedAuthKey2) @@ -583,7 +582,7 @@ func TestSessionCacheWithKey(t *testing.T) { require.NoError(t, err) // cache must contain session token - cp, err := pool.connection() + cp, err := pool.manager.connection() require.NoError(t, err) st, _ := pool.cache.Get(formCacheKey(cp.address(), pool.key, false)) require.True(t, st.AssertAuthKey(&expectedAuthKey)) @@ -636,9 +635,8 @@ func TestSessionTokenOwner(t *testing.T) { cc.sessionTarget = func(tok session.Object) { tkn = tok } - err = p.initCallContext(&cc, prm, prmCtx) + err = p.initCall(&cc, prm, prmCtx) require.NoError(t, err) - err = p.openDefaultSession(ctx, &cc) require.NoError(t, err) require.True(t, tkn.VerifySignature()) @@ -922,14 +920,14 @@ func TestSwitchAfterErrorThreshold(t *testing.T) { t.Cleanup(pool.Close) for range errorThreshold { - conn, err := pool.connection() + conn, err := pool.manager.connection() require.NoError(t, err) require.Equal(t, nodes[0].address, conn.address()) _, err = conn.objectGet(ctx, PrmObjectGet{}) require.Error(t, err) } - conn, err := pool.connection() + conn, err := pool.manager.connection() require.NoError(t, err) require.Equal(t, nodes[1].address, conn.address()) _, err = conn.objectGet(ctx, PrmObjectGet{}) diff --git a/pool/sampler_test.go b/pool/sampler_test.go index ab06e0f..b0860b1 100644 --- a/pool/sampler_test.go +++ b/pool/sampler_test.go @@ -47,9 +47,6 @@ func TestHealthyReweight(t *testing.T) { buffer = make([]float64, len(weights)) ) - cache, err := newCache(0) - require.NoError(t, err) - client1 := newMockClient(names[0], *newPrivateKey(t)) client1.errOnDial() @@ -59,22 +56,20 @@ func TestHealthyReweight(t *testing.T) { sampler: newSampler(weights, rand.NewSource(0)), clients: []client{client1, client2}, } - p := &Pool{ + cm := &connectionManager{ innerPools: []*innerPool{inner}, - cache: cache, - key: newPrivateKey(t), rebalanceParams: rebalanceParameters{nodesParams: []*nodesParam{{weights: weights}}}, } // check getting first node connection before rebalance happened - connection0, err := p.connection() + connection0, err := cm.connection() require.NoError(t, err) mock0 := connection0.(*mockClient) require.Equal(t, names[0], 
mock0.address()) - p.updateInnerNodesHealth(context.TODO(), 0, buffer) + cm.updateInnerNodesHealth(context.TODO(), 0, buffer) - connection1, err := p.connection() + connection1, err := cm.connection() require.NoError(t, err) mock1 := connection1.(*mockClient) require.Equal(t, names[1], mock1.address()) @@ -84,10 +79,10 @@ func TestHealthyReweight(t *testing.T) { inner.clients[0] = newMockClient(names[0], *newPrivateKey(t)) inner.lock.Unlock() - p.updateInnerNodesHealth(context.TODO(), 0, buffer) + cm.updateInnerNodesHealth(context.TODO(), 0, buffer) inner.sampler = newSampler(weights, rand.NewSource(0)) - connection0, err = p.connection() + connection0, err = cm.connection() require.NoError(t, err) mock0 = connection0.(*mockClient) require.Equal(t, names[0], mock0.address()) @@ -108,12 +103,12 @@ func TestHealthyNoReweight(t *testing.T) { newMockClient(names[1], *newPrivateKey(t)), }, } - p := &Pool{ + cm := &connectionManager{ innerPools: []*innerPool{inner}, rebalanceParams: rebalanceParameters{nodesParams: []*nodesParam{{weights: weights}}}, } - p.updateInnerNodesHealth(context.TODO(), 0, buffer) + cm.updateInnerNodesHealth(context.TODO(), 0, buffer) inner.lock.RLock() defer inner.lock.RUnlock()
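
For reference, a minimal usage sketch of the Pool API this patch touches: after the refactoring, NewPool only assembles the pool and its connection manager, Dial must be called before any request (it opens connections, creates session tokens and caches network info), and Close releases everything through the manager. This is a hedged illustration, not part of the patch: the option setters used below (SetKey, AddNode, SetClientRebalanceInterval) and the NewNodeParam constructor are assumed to match the public InitParameters API of this package, and the endpoint address is a placeholder; verify the exact names against your revision.

```go
package main

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"log"
	"time"

	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/pool"
)

func main() {
	// A throwaway key for the sketch; real code would load the user's key.
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		log.Fatal(err)
	}

	// Assumed setter names; see InitParameters in this package for the exact API.
	var prm pool.InitParameters
	prm.SetKey(key)
	prm.AddNode(pool.NewNodeParam(1, "grpc://frostfs.node.example:8080", 1))
	prm.SetClientRebalanceInterval(30 * time.Second)

	// NewPool only validates parameters and builds the connection manager;
	// no network activity happens here.
	p, err := pool.NewPool(prm)
	if err != nil {
		log.Fatal(err)
	}

	// Dial is mandatory before any other call: it dials the configured nodes,
	// creates FrostFS session tokens and fetches network info (e.g. the max
	// object size cached by the pool).
	ctx := context.Background()
	if err := p.Dial(ctx); err != nil {
		log.Fatal(err)
	}
	defer p.Close()

	ni, err := p.NetworkInfo(ctx)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("max object size:", ni.MaxObjectSize())
}
```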