diff --git a/hrw.go b/hrw.go index 8ceae9e..57f49fd 100644 --- a/hrw.go +++ b/hrw.go @@ -243,24 +243,6 @@ func ValidateWeights(weights []float64) error { return nil } -func newSorter(l int, byIndex bool, nodes []uint64, h uint64, - swap func(i, j int)) (*sorter, []int, []uint64) { - ind := make([]int, l) - dist := make([]uint64, l) - for i := 0; i < l; i++ { - ind[i] = i - dist[i] = getDistance(byIndex, i, nodes, h) - } - - return &sorter{ - l: l, - swap: func(i, j int) { - swap(i, j) - ind[i], ind[j] = ind[j], ind[i] - }, - }, ind, dist -} - // sortByWeight sorts nodes by weight using provided swapper. // nodes contains hrw hashes. If it is nil, indices are used. func sortByWeight(l int, byIndex bool, nodes []uint64, weights []float64, hash uint64, swap func(i, j int)) { @@ -270,14 +252,23 @@ func sortByWeight(l int, byIndex bool, nodes []uint64, weights []float64, hash u return } - s, ind, dist := newSorter(l, byIndex, nodes, hash, swap) - s.less = func(i, j int) bool { - ii, jj := ind[i], ind[j] + dist := make([]float64, l) + for i := 0; i < l; i++ { + d := getDistance(byIndex, i, nodes, hash) // `maxUint64 - distance` makes the shorter distance more valuable // it is necessary for operation with normalized values - wi := float64(^uint64(0)-dist[ii]) * weights[ii] - wj := float64(^uint64(0)-dist[jj]) * weights[jj] - return wi > wj // higher distance must be placed lower to be first + dist[i] = float64(^uint64(0)-d) * weights[i] + } + + s := &sorter{ + l: l, + swap: func(i, j int) { + swap(i, j) + dist[i], dist[j] = dist[j], dist[i] + }, + less: func(i, j int) bool { + return dist[i] > dist[j] // higher distance must be placed lower to be first + }, } sort.Sort(s) } @@ -285,9 +276,20 @@ func sortByWeight(l int, byIndex bool, nodes []uint64, weights []float64, hash u // sortByDistance sorts nodes by hrw distance using provided swapper. // nodes contains hrw hashes. If it is nil, indices are used. func sortByDistance(l int, byIndex bool, nodes []uint64, hash uint64, swap func(i, j int)) { - s, ind, dist := newSorter(l, byIndex, nodes, hash, swap) - s.less = func(i, j int) bool { - return dist[ind[i]] < dist[ind[j]] + dist := make([]uint64, l) + for i := 0; i < l; i++ { + dist[i] = getDistance(byIndex, i, nodes, hash) + } + + s := &sorter{ + l: l, + swap: func(i, j int) { + swap(i, j) + dist[i], dist[j] = dist[j], dist[i] + }, + less: func(i, j int) bool { + return dist[i] < dist[j] + }, } sort.Sort(s) }