[#222] auditsvc: Refactor audit task

Resolve the containedctx linter warning by removing the context stored inside the audit task. Cancel the task via a cancel channel when the listener's context is cancelled.

Signed-off-by: Dmitrii Stepanov <d.stepanov@yadro.com>
Dmitrii Stepanov 2023-04-06 11:36:25 +03:00
parent 3dbff0a478
commit e8d340287f
10 changed files with 66 additions and 58 deletions
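For context, the following is a minimal, self-contained sketch of the pattern this change applies (all identifiers in the sketch are illustrative, not the actual frostfs-node names): the producer keeps an idempotent cancel function around a plain channel instead of storing a context inside the task, and the consumer bridges that channel to a context derived from its own listen context, so the containedctx linter no longer fires while the task is still cancelled on a new epoch or when listening stops.

// Sketch only: illustrative names, not the real frostfs-node code.
package main

import (
    "context"
    "fmt"
    "time"
)

func main() {
    // Producer side: keep a closable channel plus an idempotent canceler
    // instead of storing a context inside the task.
    cancelCh := make(chan struct{})
    cancel := func() {
        select {
        case <-cancelCh: // already closed
        default:
            close(cancelCh)
        }
    }

    // Consumer side: derive a per-task context from the listener's context
    // and cancel it when either the listener stops or the channel closes.
    listenCtx := context.Background()
    taskCtx, taskCancel := context.WithCancel(listenCtx)
    go func() {
        select {
        case <-taskCtx.Done(): // listener cancelled or task completed
        case <-cancelCh: // e.g. a new epoch started
            taskCancel()
        }
    }()

    cancel() // simulate a new epoch: the per-task context gets cancelled
    cancel() // calling it twice is safe thanks to the select guard

    select {
    case <-taskCtx.Done():
        fmt.Println("task cancelled:", taskCtx.Err())
    case <-time.After(time.Second):
        fmt.Println("task still running")
    }
    taskCancel() // release the derived context's resources
}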


@@ -46,15 +46,21 @@ func (ap *Processor) processStartAudit(epoch uint64) {
         return
     }
 
-    var auditCtx context.Context
-    auditCtx, ap.prevAuditCanceler = context.WithCancel(context.Background())
+    cancelChannel := make(chan struct{})
+    ap.prevAuditCanceler = func() {
+        select {
+        case <-cancelChannel: // already closed
+        default:
+            close(cancelChannel)
+        }
+    }
 
     pivot := make([]byte, sha256.Size)
 
-    ap.startAuditTasksOnContainers(auditCtx, containers, log, pivot, nm, epoch)
+    ap.startAuditTasksOnContainers(cancelChannel, containers, log, pivot, nm, epoch)
 }
 
-func (ap *Processor) startAuditTasksOnContainers(ctx context.Context, containers []cid.ID, log *zap.Logger, pivot []byte, nm *netmap.NetMap, epoch uint64) {
+func (ap *Processor) startAuditTasksOnContainers(cancelChannel <-chan struct{}, containers []cid.ID, log *zap.Logger, pivot []byte, nm *netmap.NetMap, epoch uint64) {
     for i := range containers {
         cnr, err := cntClient.Get(ap.containerClient, containers[i]) // get container structure
         if err != nil {
@@ -107,18 +113,14 @@ func (ap *Processor) startAuditTasksOnContainers(ctx context.Context, containers
                 epoch: epoch,
                 rep:   ap.reporter,
             }).
-            WithAuditContext(ctx).
+            WithCancelChannel(cancelChannel).
             WithContainerID(containers[i]).
             WithStorageGroupList(storageGroups).
             WithContainerStructure(cnr.Value).
             WithContainerNodes(nodes).
             WithNetworkMap(nm)
 
-        if err := ap.taskManager.PushTask(auditTask); err != nil {
-            ap.log.Error("could not push audit task",
-                zap.String("error", err.Error()),
-            )
-        }
+        ap.taskManager.PushTask(auditTask)
     }
 }


@@ -24,7 +24,7 @@ type (
     }
 
     TaskManager interface {
-        PushTask(*audit.Task) error
+        PushTask(*audit.Task)
 
         // Must skip all tasks planned for execution and
         // return their number.


@@ -194,9 +194,7 @@ func (c *Context) init() {
     )}
 }
 
-func (c *Context) expired() bool {
-    ctx := c.task.AuditContext()
-
+func (c *Context) expired(ctx context.Context) bool {
     select {
     case <-ctx.Done():
         c.log.Debug("audit context is done",


@@ -1,16 +1,18 @@
 package auditor
 
 import (
+    "context"
     "fmt"
 )
 
 // Execute audits container data.
-func (c *Context) Execute() {
+func (c *Context) Execute(ctx context.Context, onCompleted func()) {
+    defer onCompleted()
     c.init()
 
     checks := []struct {
         name string
-        exec func()
+        exec func(context.Context)
     }{
         {name: "PoR", exec: c.executePoR},
         {name: "PoP", exec: c.executePoP},
@@ -20,11 +22,11 @@ func (c *Context) Execute() {
     for i := range checks {
         c.log.Debug(fmt.Sprintf("executing %s check...", checks[i].name))
 
-        if c.expired() {
+        if c.expired(ctx) {
             break
         }
 
-        checks[i].exec()
+        checks[i].exec(ctx)
 
         if i == len(checks)-1 {
             c.complete()


@@ -2,6 +2,7 @@ package auditor
 
 import (
     "bytes"
+    "context"
     "sync"
     "time"
@@ -13,12 +14,12 @@ import (
     "go.uber.org/zap"
 )
 
-func (c *Context) executePDP() {
-    c.processPairs()
+func (c *Context) executePDP(ctx context.Context) {
+    c.processPairs(ctx)
     c.writePairsResult()
 }
 
-func (c *Context) processPairs() {
+func (c *Context) processPairs(ctx context.Context) {
     wg := new(sync.WaitGroup)
 
     for i := range c.pairs {
@@ -26,7 +27,7 @@ func (c *Context) processPairs() {
         wg.Add(1)
         if err := c.pdpWorkerPool.Submit(func() {
-            c.processPair(p)
+            c.processPair(ctx, p)
             wg.Done()
         }); err != nil {
             wg.Done()
@@ -37,9 +38,9 @@ func (c *Context) processPairs() {
     c.pdpWorkerPool.Release()
 }
 
-func (c *Context) processPair(p *gamePair) {
+func (c *Context) processPair(ctx context.Context, p *gamePair) {
     c.distributeRanges(p)
-    c.collectHashes(p)
+    c.collectHashes(ctx, p)
     c.analyzeHashes(p)
 }
 
@@ -106,7 +107,7 @@ func (c *Context) splitPayload(id oid.ID) []uint64 {
     return notches
 }
 
-func (c *Context) collectHashes(p *gamePair) {
+func (c *Context) collectHashes(ctx context.Context, p *gamePair) {
     fn := func(n netmap.NodeInfo, rngs []*object.Range) [][]byte {
         // Here we randomize the order a bit: the hypothesis is that this
         // makes it harder for an unscrupulous node to come up with a
@@ -137,7 +138,7 @@ func (c *Context) collectHashes(p *gamePair) {
         getRangeHashPrm.Range = rngs[i]
 
-        h, err := c.cnrCom.GetRangeHash(c.task.AuditContext(), getRangeHashPrm)
+        h, err := c.cnrCom.GetRangeHash(ctx, getRangeHashPrm)
         if err != nil {
             c.log.Debug("could not get payload range hash",
                 zap.Stringer("id", p.id),


@@ -1,6 +1,8 @@
 package auditor
 
 import (
+    "context"
+
     "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/netmap"
     oid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/id"
     "git.frostfs.info/TrueCloudLab/tzhash/tz"
@@ -12,8 +14,8 @@ const (
     minGamePayloadSize = hashRangeNumber * tz.Size
 )
 
-func (c *Context) executePoP() {
-    c.buildCoverage()
+func (c *Context) executePoP(ctx context.Context) {
+    c.buildCoverage(ctx)
 
     c.report.SetPlacementCounters(
         c.counters.hit,
@@ -22,13 +24,13 @@ func (c *Context) executePoP() {
     )
 }
 
-func (c *Context) buildCoverage() {
+func (c *Context) buildCoverage(ctx context.Context) {
     policy := c.task.ContainerStructure().PlacementPolicy()
 
     // select random member from another storage group
     // and process all placement vectors
     c.iterateSGMembersPlacementRand(func(id oid.ID, ind int, nodes []netmap.NodeInfo) bool {
-        c.processObjectPlacement(id, nodes, policy.ReplicaNumberByIndex(ind))
+        c.processObjectPlacement(ctx, id, nodes, policy.ReplicaNumberByIndex(ind))
         return c.containerCovered()
     })
 }
@@ -38,7 +40,7 @@ func (c *Context) containerCovered() bool {
     return c.cnrNodesNum <= len(c.pairedNodes)
 }
 
-func (c *Context) processObjectPlacement(id oid.ID, nodes []netmap.NodeInfo, replicas uint32) {
+func (c *Context) processObjectPlacement(ctx context.Context, id oid.ID, nodes []netmap.NodeInfo, replicas uint32) {
     var (
         ok      uint32
         optimal bool
@@ -57,7 +59,7 @@ func (c *Context) processObjectPlacement(id oid.ID, nodes []netmap.NodeInfo, rep
         getHeaderPrm.Node = nodes[i]
 
         // try to get object header from node
-        hdr, err := c.cnrCom.GetHeader(c.task.AuditContext(), getHeaderPrm)
+        hdr, err := c.cnrCom.GetHeader(ctx, getHeaderPrm)
         if err != nil {
             c.log.Debug("could not get object header from candidate",
                 zap.Stringer("id", id),


@@ -2,6 +2,7 @@ package auditor
 
 import (
     "bytes"
+    "context"
     "sync"
 
     "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/services/object_manager/placement"
@@ -14,7 +15,7 @@ import (
     "go.uber.org/zap"
 )
 
-func (c *Context) executePoR() {
+func (c *Context) executePoR(ctx context.Context) {
     wg := new(sync.WaitGroup)
     sgs := c.task.StorageGroupList()
 
@@ -22,7 +23,7 @@ func (c *Context) executePoR() {
         wg.Add(1)
         if err := c.porWorkerPool.Submit(func() {
-            c.checkStorageGroupPoR(sg.ID(), sg.StorageGroup())
+            c.checkStorageGroupPoR(ctx, sg.ID(), sg.StorageGroup())
             wg.Done()
         }); err != nil {
             wg.Done()
@@ -36,7 +37,7 @@ func (c *Context) executePoR() {
 }
 
 // nolint: funlen
-func (c *Context) checkStorageGroupPoR(sgID oid.ID, sg storagegroupSDK.StorageGroup) {
+func (c *Context) checkStorageGroupPoR(ctx context.Context, sgID oid.ID, sg storagegroupSDK.StorageGroup) {
     members := sg.Members()
     c.updateSGInfo(sgID, members)
 
@@ -80,7 +81,7 @@ func (c *Context) checkStorageGroupPoR(sgID oid.ID, sg storagegroupSDK.StorageGr
         getHeaderPrm.Node = flat[j]
 
-        hdr, err := c.cnrCom.GetHeader(c.task.AuditContext(), getHeaderPrm)
+        hdr, err := c.cnrCom.GetHeader(ctx, getHeaderPrm)
         if err != nil {
             c.log.Debug("can't head object",
                 zap.String("remote_node", netmap.StringifyPublicKey(flat[j])),


@@ -1,8 +1,6 @@
 package audit
 
 import (
-    "context"
-
     "git.frostfs.info/TrueCloudLab/frostfs-node/pkg/core/storagegroup"
     "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container"
     cid "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id"
@@ -10,11 +8,10 @@ import (
 )
 
 // Task groups groups the container audit parameters.
-// nolint: containedctx
 type Task struct {
-    reporter Reporter
+    cancelCh <-chan struct{}
 
-    auditContext context.Context
+    reporter Reporter
 
     idCnr cid.ID
@@ -41,18 +38,15 @@ func (t *Task) Reporter() Reporter {
     return t.reporter
 }
 
-// WithAuditContext sets context of the audit of the current epoch.
-func (t *Task) WithAuditContext(ctx context.Context) *Task {
-    if t != nil {
-        t.auditContext = ctx
+func (t *Task) WithCancelChannel(ch <-chan struct{}) *Task {
+    if ch != nil {
+        t.cancelCh = ch
     }
 
     return t
 }
 
-// AuditContext returns context of the audit of the current epoch.
-func (t *Task) AuditContext() context.Context {
-    return t.auditContext
+func (t *Task) CancelChannel() <-chan struct{} {
+    return t.cancelCh
 }
 
 // WithContainerID sets identifier of the container under audit.


@@ -33,18 +33,28 @@ func (m *Manager) Listen(ctx context.Context) {
                 return
             }
 
-            m.handleTask(task)
+            tCtx, tCancel := context.WithCancel(ctx) // cancel task in case of listen cancel
+            go func() {
+                select {
+                case <-tCtx.Done(): // listen cancelled or task completed
+                    return
+                case <-task.CancelChannel(): // new epoch
+                    tCancel()
+                }
+            }()
+            m.handleTask(tCtx, task, tCancel)
         }
     }
 }
 
-func (m *Manager) handleTask(task *audit.Task) {
+func (m *Manager) handleTask(ctx context.Context, task *audit.Task, onCompleted func()) {
     pdpPool, err := m.pdpPoolGenerator()
     if err != nil {
         m.log.Error("could not generate PDP worker pool",
             zap.String("error", err.Error()),
         )
+        onCompleted()
         return
     }
 
@@ -53,7 +63,7 @@ func (m *Manager) handleTask(task *audit.Task) {
         m.log.Error("could not generate PoR worker pool",
             zap.String("error", err.Error()),
         )
+        onCompleted()
         return
     }
 
@@ -61,9 +71,10 @@ func (m *Manager) handleTask(task *audit.Task) {
         WithPDPWorkerPool(pdpPool).
         WithPoRWorkerPool(porPool)
 
-    if err := m.workerPool.Submit(auditContext.Execute); err != nil {
+    if err := m.workerPool.Submit(func() { auditContext.Execute(ctx, onCompleted) }); err != nil {
         // may be we should report it
         m.log.Warn("could not submit audit task")
+        onCompleted()
     }
 }


@@ -5,9 +5,6 @@ import (
 )
 
 // PushTask adds a task to the queue for processing.
-//
-// Returns error if task was not added to the queue.
-func (m *Manager) PushTask(t *audit.Task) error {
+func (m *Manager) PushTask(t *audit.Task) {
     m.ch <- t
-    return nil
 }