package statesync import ( "bytes" "sort" "sync" "github.com/nspcc-dev/neo-go/pkg/util" ) // Pool stores unknown MPT nodes along with the corresponding paths (single node is // allowed to have multiple MPT paths). type Pool struct { lock sync.RWMutex hashes map[util.Uint256][][]byte } // NewPool returns new MPT node hashes pool. func NewPool() *Pool { return &Pool{ hashes: make(map[util.Uint256][][]byte), } } // ContainsKey checks if MPT node with the specified hash is in the Pool. func (mp *Pool) ContainsKey(hash util.Uint256) bool { mp.lock.RLock() defer mp.lock.RUnlock() _, ok := mp.hashes[hash] return ok } // TryGet returns a set of MPT paths for the specified HashNode. func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) { mp.lock.RLock() defer mp.lock.RUnlock() paths, ok := mp.hashes[hash] return paths, ok } // GetAll returns all MPT nodes with the corresponding paths from the pool. func (mp *Pool) GetAll() map[util.Uint256][][]byte { mp.lock.RLock() defer mp.lock.RUnlock() return mp.hashes } // GetBatch returns set of unknown MPT nodes hashes (`limit` at max). func (mp *Pool) GetBatch(limit int) []util.Uint256 { mp.lock.RLock() defer mp.lock.RUnlock() count := len(mp.hashes) if count > limit { count = limit } result := make([]util.Uint256, 0, limit) for h := range mp.hashes { if count == 0 { break } result = append(result, h) count-- } return result } // Remove removes MPT node from the pool by the specified hash. func (mp *Pool) Remove(hash util.Uint256) { mp.lock.Lock() defer mp.lock.Unlock() delete(mp.hashes, hash) } // Add adds path to the set of paths for the specified node. func (mp *Pool) Add(hash util.Uint256, path []byte) { mp.lock.Lock() defer mp.lock.Unlock() mp.addPaths(hash, [][]byte{path}) } // Update is an atomic operation and removes/adds specified nodes from/to the pool. func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) { mp.lock.Lock() defer mp.lock.Unlock() for h, paths := range remove { old := mp.hashes[h] for _, path := range paths { i := sort.Search(len(old), func(i int) bool { return bytes.Compare(old[i], path) >= 0 }) if i < len(old) && bytes.Equal(old[i], path) { old = append(old[:i], old[i+1:]...) } } if len(old) == 0 { delete(mp.hashes, h) } else { mp.hashes[h] = old } } for h, paths := range add { mp.addPaths(h, paths) } } // addPaths adds set of the specified node paths to the pool. func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) { old := mp.hashes[nodeHash] for _, path := range paths { i := sort.Search(len(old), func(i int) bool { return bytes.Compare(old[i], path) >= 0 }) if i < len(old) && bytes.Equal(old[i], path) { // then path is already added continue } old = append(old, path) if i != len(old)-1 { copy(old[i+1:], old[i:]) old[i] = path } } mp.hashes[nodeHash] = old } // Count returns the number of nodes in the pool. func (mp *Pool) Count() int { mp.lock.RLock() defer mp.lock.RUnlock() return len(mp.hashes) }