diff --git a/hrw.go b/hrw.go index f946dda..c1a990f 100644 --- a/hrw.go +++ b/hrw.go @@ -5,6 +5,7 @@ package hrw import ( "encoding/binary" "errors" + "math" "reflect" "sort" @@ -293,7 +294,7 @@ func prepareRule(slice interface{}) []uint64 { // ValidateWeights checks if weights are normalized between 0.0 and 1.0 func ValidateWeights(weights []float64) error { for i := range weights { - if weights[i] > NormalizedMaxWeight || weights[i] < NormalizedMinWeight { + if math.IsNaN(weights[i]) || weights[i] > NormalizedMaxWeight || weights[i] < NormalizedMinWeight { return errors.New("weights are not normalized") } } diff --git a/hrw_test.go b/hrw_test.go index 389d88d..fb7a34d 100644 --- a/hrw_test.go +++ b/hrw_test.go @@ -76,6 +76,9 @@ func TestValidateWeights(t *testing.T) { weights := []float64{10, 10, 10, 2, 2, 2} err := ValidateWeights(weights) require.Error(t, err) + weights = []float64{math.NaN(), 1, 1, 0.2, 0.2, 0.2} + err = ValidateWeights(weights) + require.Error(t, err) weights = []float64{1, 1, 1, 0.2, 0.2, 0.2} err = ValidateWeights(weights) require.NoError(t, err)