diff --git a/plugin/loadbalance/weighted.go b/plugin/loadbalance/weighted.go index be49b18b1..bdd8232f8 100644 --- a/plugin/loadbalance/weighted.go +++ b/plugin/loadbalance/weighted.go @@ -254,26 +254,26 @@ func (w *weightedRR) updateWeights() error { scanner := bufio.NewScanner(&buf) // Parse the weight file contents - err = w.parseWeights(scanner) + domains, err := w.parseWeights(scanner) if err != nil { return err } + // access to weights must be protected + w.mutex.Lock() + w.domains = domains + w.mutex.Unlock() + log.Infof("Successfully reloaded weight file %s", w.fileName) return nil } // Parse the weight file contents -func (w *weightedRR) parseWeights(scanner *bufio.Scanner) error { - // access to weights must be protected - w.mutex.Lock() - defer w.mutex.Unlock() - - // Reset domains - w.domains = make(map[string]weights) - +func (w *weightedRR) parseWeights(scanner *bufio.Scanner) (map[string]weights, error) { var dname string var ws weights + domains := make(map[string]weights) + for scanner.Scan() { nextLine := strings.TrimSpace(scanner.Text()) if len(nextLine) == 0 || nextLine[0:1] == "#" { @@ -285,7 +285,7 @@ func (w *weightedRR) parseWeights(scanner *bufio.Scanner) error { case 1: // (domain) name sanity check if net.ParseIP(fields[0]) != nil { - return fmt.Errorf("Wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)", + return nil, fmt.Errorf("Wrong domain name:\"%s\" in weight file %s. (Maybe a missing weight value?)", fields[0], w.fileName) } dname = fields[0] @@ -295,35 +295,35 @@ func (w *weightedRR) parseWeights(scanner *bufio.Scanner) error { dname += "." } var ok bool - ws, ok = w.domains[dname] + ws, ok = domains[dname] if !ok { ws = make(weights, 0) - w.domains[dname] = ws + domains[dname] = ws } case 2: // IP address and weight value ip := net.ParseIP(fields[0]) if ip == nil { - return fmt.Errorf("Wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName) + return nil, fmt.Errorf("Wrong IP address:\"%s\" in weight file %s", fields[0], w.fileName) } weight, err := strconv.ParseUint(fields[1], 10, 8) - if err != nil { - return fmt.Errorf("Wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName) + if err != nil || weight == 0 { + return nil, fmt.Errorf("Wrong weight value:\"%s\" in weight file %s", fields[1], w.fileName) } witem := &weightItem{address: ip, value: uint8(weight)} if dname == "" { - return fmt.Errorf("Missing domain name in weight file %s", w.fileName) + return nil, fmt.Errorf("Missing domain name in weight file %s", w.fileName) } ws = append(ws, witem) - w.domains[dname] = ws + domains[dname] = ws default: - return fmt.Errorf("Could not parse weight line:\"%s\" in weight file %s", nextLine, w.fileName) + return nil, fmt.Errorf("Could not parse weight line:\"%s\" in weight file %s", nextLine, w.fileName) } } if err := scanner.Err(); err != nil { - return fmt.Errorf("Weight file %s parsing error:%s", w.fileName, err) + return nil, fmt.Errorf("Weight file %s parsing error:%s", w.fileName, err) } - return nil + return domains, nil } diff --git a/plugin/loadbalance/weighted_test.go b/plugin/loadbalance/weighted_test.go index e502c2772..fe8f5950e 100644 --- a/plugin/loadbalance/weighted_test.go +++ b/plugin/loadbalance/weighted_test.go @@ -79,6 +79,11 @@ w1,example.org 192.168.1.14 300 ` +const zeroWeightWRR = ` +w1,example.org +192.168.1.14 0 +` + func TestWeightFileUpdate(t *testing.T) { tests := []struct { weightFilContent string @@ -95,6 +100,7 @@ func TestWeightFileUpdate(t *testing.T) { {missingDomainWRR, true, nil, "Missing domain name"}, {wrongIpWRR, true, nil, "Wrong IP address"}, {wrongWeightWRR, true, nil, "Wrong weight value"}, + {zeroWeightWRR, true, nil, "Wrong weight value"}, } for i, test := range tests {