diff --git a/hrw.go b/hrw.go index c874917..9059636 100644 --- a/hrw.go +++ b/hrw.go @@ -21,6 +21,12 @@ type ( less func(i, j int) bool swap func(i, j int) } + + hasherSorter[T Hasher, N interface{ ~uint64 | ~float64 }] struct { + slice []T + dist []N + asc bool + } ) // Boundaries of valid normalized weights @@ -33,6 +39,18 @@ func (s *sorter) Len() int { return s.l } func (s *sorter) Less(i, j int) bool { return s.less(i, j) } func (s *sorter) Swap(i, j int) { s.swap(i, j) } +func (s *hasherSorter[T, N]) Len() int { return len(s.slice) } +func (s *hasherSorter[T, N]) Less(i, j int) bool { + if s.asc { + return s.dist[i] < s.dist[j] + } + return s.dist[i] > s.dist[j] +} +func (s *hasherSorter[T, N]) Swap(i, j int) { + s.slice[i], s.slice[j] = s.slice[j], s.slice[i] + s.dist[i], s.dist[j] = s.dist[j], s.dist[i] +} + func distance(x uint64, y uint64) uint64 { acc := x ^ y // here used mmh3 64 bit finalizer @@ -128,29 +146,19 @@ func SortHasherSliceByWeightValue[T Hasher](slice []T, weights []float64, hash u dist[i] = float64(^uint64(0)-d) * weights[i] } - sort.Sort(&sorter{ - l: len(slice), - swap: func(i, j int) { - slice[i], slice[j] = slice[j], slice[i] - dist[i], dist[j] = dist[j], dist[i] - }, - less: func(i, j int) bool { - return dist[i] > dist[j] // higher distance must be placed lower to be first - }, + sort.Sort(&hasherSorter[T, float64]{ + slice: slice, + dist: dist, + asc: false, }) } // sortHasherByDistance is similar to sortByDistance but accepts slice directly. func sortHasherByDistance[T Hasher](slice []T, byIndex bool, dist []uint64) { - sort.Sort(&sorter{ - l: len(slice), - swap: func(i, j int) { - slice[i], slice[j] = slice[j], slice[i] - dist[i], dist[j] = dist[j], dist[i] - }, - less: func(i, j int) bool { - return dist[i] < dist[j] - }, + sort.Sort(&hasherSorter[T, uint64]{ + slice: slice, + dist: dist, + asc: true, }) }