Use source IPs instead of interfaces #11

Merged
dstepanov-yadro merged 8 commits from dstepanov-yadro/multinet:fix/use_source_ips into master 2024-11-02 14:21:56 +00:00
25 changed files with 512 additions and 811 deletions

View file

@ -0,0 +1,21 @@
name: DCO action
on: [pull_request]
jobs:
dco:
name: DCO
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
- name: Run commit format checker
uses: https://git.frostfs.info/TrueCloudLab/dco-go@v3
with:
from: 'origin/${{ github.event.pull_request.base.ref }}'

View file

@ -0,0 +1,25 @@
name: Pre-commit hooks
on: [pull_request]
jobs:
precommit:
name: Pre-commit
env:
# Skip pre-commit hooks which are executed by other actions.
SKIP: make-lint,go-staticcheck-repo-mod,go-unit-tests,gofumpt
runs-on: ubuntu-22.04
# If we use actions/setup-python from either Github or Gitea,
# the line above fails with a cryptic error about not being able to find python.
# So install everything manually.
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.23
- name: Set up Python
run: |
apt update
apt install -y pre-commit
- name: Run pre-commit
run: pre-commit run --color=always --hook-stage manual --all-files

View file

@ -0,0 +1,111 @@
name: Tests and linters
on: [pull_request]
jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Install linters
run: make lint-install
- name: Run linters
run: make lint
tests:
name: Tests
runs-on: ubuntu-latest
strategy:
matrix:
go_versions: [ '1.22', '1.23' ]
fail-fast: false
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '${{ matrix.go_versions }}'
cache: true
- name: Run tests
run: make test
tests-race:
name: Tests with -race
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Run tests
run: go test ./... -count=1 -race
staticcheck:
name: Staticcheck
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Install staticcheck
run: make staticcheck-install
- name: Run staticcheck
run: make staticcheck-run
gopls:
name: gopls check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Install gopls
run: make gopls-install
- name: Run gopls
run: make gopls-run
fumpt:
name: Run gofumpt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Install gofumpt
run: make fumpt-install
- name: Run gofumpt
run: |
make fumpt
git diff --exit-code --quiet

View file

@ -0,0 +1,22 @@
name: Vulncheck
on: [pull_request]
jobs:
vulncheck:
name: Vulncheck
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
- name: Install govulncheck
run: go install golang.org/x/vuln/cmd/govulncheck@latest
- name: Run govulncheck
run: govulncheck ./...

1
.gitattributes vendored Normal file
View file

@ -0,0 +1 @@
/go.sum -diff

22
.gitignore vendored Normal file
View file

@ -0,0 +1,22 @@
# IDE
.idea
.vscode
# Vendoring
vendor
# tempfiles
.DS_Store
*~
.cache
temp
tmp
# binary
bin/
release/
# coverage
coverage.txt
coverage.html

75
.golangci.yml Normal file
View file

@ -0,0 +1,75 @@
# This file contains all available configuration options
# with their default values.
# options for analysis running
run:
# timeout for analysis, e.g. 30s, 5m, default is 1m
timeout: 20m
# include test files or not, default is true
tests: false
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number"
formats:
- format: tab
# all available settings of specific linters
linters-settings:
exhaustive:
# indicates that switch statements are to be considered exhaustive if a
# 'default' case is present, even if all enum members aren't listed in the
# switch
default-signifies-exhaustive: true
govet:
# report about shadowed variables
check-shadowing: false
staticcheck:
checks: ["all", "-SA1019"] # TODO Enable SA1019 after deprecated warning are fixed.
funlen:
lines: 80 # default 60
statements: 60 # default 40
gocognit:
min-complexity: 40 # default 30
unused:
field-writes-are-uses: false
exported-fields-are-used: false
local-variables-are-used: false
linters:
enable:
# mandatory linters
- govet
- revive
# some default golangci-lint linters
- errcheck
- gosimple
- godot
- ineffassign
- staticcheck
- typecheck
- unused
# extra linters
- bidichk
- durationcheck
- exhaustive
- copyloopvar
- gofmt
- goimports
- misspell
- predeclared
- reassign
- whitespace
- containedctx
- funlen
- gocognit
- contextcheck
- importas
- perfsprint
- testifylint
- protogetter
disable-all: true
fast: false

56
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,56 @@
ci:
autofix_prs: false
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-shebang-scripts-are-executable
- id: check-merge-conflict
- id: check-json
- id: check-xml
- id: check-yaml
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: end-of-file-fixer
exclude: "(.key|.svg)$"
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.9.0.6
hooks:
- id: shellcheck
- repo: local
hooks:
- id: make-lint
name: Run Make Lint
entry: make lint
language: system
pass_filenames: false
- repo: local
hooks:
- id: go-unit-tests
name: go unit tests
entry: make test GOFLAGS=''
pass_filenames: false
types: [go]
language: system
- repo: local
hooks:
- id: gofumpt
name: gofumpt
entry: make fumpt
pass_filenames: false
types: [go]
language: system
- repo: https://github.com/TekWizely/pre-commit-golang
rev: v1.0.0-rc.1
hooks:
- id: go-staticcheck-repo-mod
- id: go-mod-tidy

83
Makefile Normal file → Executable file
View file

@ -1,9 +1,82 @@
integration-test:
# TODO figure out needed capabilities
sudo go test -count=1 -v ./... -tags=integration
#!/usr/bin/make -f
STATICCHECK_VERSION ?= 2024.1.1
LINT_VERSION ?= 1.60.3
BIN = bin
OUTPUT_LINT_DIR ?= $(abspath $(BIN))/linters
LINT_DIR = $(OUTPUT_LINT_DIR)/golangci-lint-$(LINT_VERSION)
TMP_DIR := .cache
STATICCHECK_DIR ?= $(abspath $(BIN))/staticcheck
STATICCHECK_VERSION_DIR ?= $(STATICCHECK_DIR)/$(STATICCHECK_VERSION)
GOFUMPT_VERSION ?= v0.7.0
GOFUMPT_DIR ?= $(abspath $(BIN))/gofumpt
GOFUMPT_VERSION_DIR ?= $(GOFUMPT_DIR)/$(GOFUMPT_VERSION)
GOPLS_VERSION ?= v0.16.2
GOPLS_DIR ?= $(abspath $(BIN))/gopls
GOPLS_VERSION_DIR ?= $(GOPLS_DIR)/$(GOPLS_VERSION)
GOPLS_TEMP_FILE := $(shell mktemp)
test:
go test -count=1 -v ./...
patch-example:
gopatch -d -p ./multinet.patch ./testdata/patch*
# Install linters
lint-install:
@rm -rf $(OUTPUT_LINT_DIR)
@mkdir -p $(OUTPUT_LINT_DIR)
@CGO_ENABLED=1 GOBIN=$(LINT_DIR) go install -trimpath github.com/golangci/golangci-lint/cmd/golangci-lint@v$(LINT_VERSION)
# Run linters
lint:
@if [ ! -d "$(LINT_DIR)" ]; then \
make lint-install; \
fi
$(LINT_DIR)/golangci-lint run
# Install staticcheck
staticcheck-install:
@rm -rf $(STATICCHECK_DIR)
@mkdir -p $(STATICCHECK_DIR)
@GOBIN=$(STATICCHECK_VERSION_DIR) go install honnef.co/go/tools/cmd/staticcheck@$(STATICCHECK_VERSION)
# Run staticcheck
staticcheck-run:
@if [ ! -d "$(STATICCHECK_VERSION_DIR)" ]; then \
make staticcheck-install; \
fi
@$(STATICCHECK_VERSION_DIR)/staticcheck ./...
# Install gopls
gopls-install:
@rm -rf $(GOPLS_DIR)
@mkdir -p $(GOPLS_DIR)
@GOBIN=$(GOPLS_VERSION_DIR) go install golang.org/x/tools/gopls@$(GOPLS_VERSION)
# Run gopls
gopls-run:
@if [ ! -d "$(GOPLS_VERSION_DIR)" ]; then \
make gopls-install; \
fi
$(GOPLS_VERSION_DIR)/gopls check $(SOURCES) 2>&1 >$(GOPLS_TEMP_FILE)
@if [[ $$(wc -l < $(GOPLS_TEMP_FILE)) -ne 0 ]]; then \
cat $(GOPLS_TEMP_FILE); \
exit 1; \
fi
rm $(GOPLS_TEMP_FILE)
# Install gofumpt
fumpt-install:
@rm -rf $(GOFUMPT_DIR)
@mkdir -p $(GOFUMPT_DIR)
@GOBIN=$(GOFUMPT_VERSION_DIR) go install mvdan.cc/gofumpt@$(GOFUMPT_VERSION)
# Run gofumpt
fumpt:
@if [ ! -d "$(GOFUMPT_VERSION_DIR)" ]; then \
make fumpt-install; \
fi
@echo "⇒ Processing gofumpt check"
$(GOFUMPT_VERSION_DIR)/gofumpt -l -w .

View file

@ -16,10 +16,31 @@ But sometimes you need to invent a bicycle.
## Usage
```golang
import "git.frostfs.info/TrueCloudLab/multinet"
import (
"context"
"net"
"net/netip"
"git.frostfs.info/TrueCloudLab/multinet"
)
d, err := multinet.NewDialer(Config{
Subnets: []string{"10.11.70.0/23", "192.168.123.0/24"},
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.70.0/23"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("10.11.70.42"),
netip.MustParseAddr("10.11.71.42"),
},
},
{
Prefix: netip.MustParsePrefix("192.168.123.0/24"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("192.168.123.42"),
netip.MustParseAddr("192.168.123.142"),
},
},
},
Balancer: multinet.BalancerTypeRoundRobin,
})
if err != nil {
@ -32,18 +53,3 @@ if err != nil {
}
// do stuff
```
### Updating interface state
`Multidialer` exposes `UpdateInterface()` method for updating state of a single link.
`NetlinkWatcher` can wrap `Multidialer` type and perform all updates automatically.
TODO: describe needed capabilities here.
## Patch
To perform refactoring (use `multinet.Dial` instead of `net.Dial`) using [gopatch](https://github.com/uber-go/gopatch):
```bash
gopatch -p ./multinet.patch <project directory>
```

View file

@ -32,15 +32,11 @@ type roundRobin struct {
func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) {
next := int(r.i.Add(1))
for i := range s.Interfaces {
ii := s.Interfaces[(i+next)%len(s.Interfaces)]
if ii.Down {
continue
}
for i := range s.SourceIPs {
ii := s.SourceIPs[(i+next)%len(s.SourceIPs)]
dd := r.d.dialer
dd.LocalAddr = ii.LocalAddr
return r.d.dialContext(&dd, ctx, network, address)
dd.LocalAddr = &net.TCPAddr{IP: net.IP(ii.AsSlice())}
return r.d.dialContext(ctx, &dd, network, address)
}
return nil, fmt.Errorf("(*roundRobin).DialContext: %w", errNoSuitableNodeFound)
}
@ -50,15 +46,11 @@ type firstEnabled struct {
}
func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) {
for i := range s.Interfaces {
ii := s.Interfaces[i%len(s.Interfaces)]
if ii.Down {
continue
}
for i := range s.SourceIPs {
ii := s.SourceIPs[i]
dd := r.d.dialer
dd.LocalAddr = ii.LocalAddr
return r.d.dialContext(&dd, ctx, network, address)
dd.LocalAddr = &net.TCPAddr{IP: net.IP(ii.AsSlice())}
return r.d.dialContext(ctx, &dd, network, address)
}
return nil, fmt.Errorf("(*firstEnabled).DialContext: %w", errNoSuitableNodeFound)
}

View file

@ -1,31 +0,0 @@
package multinet
import (
"context"
"fmt"
"net"
)
var (
defaultDialer Multidialer
defaultDialerErr error
)
func init() {
var err error
defaultDialer, err = NewDialer(Config{
Balancer: BalancerTypeRoundRobin,
Subnets: []string{"0.0.0.0/0", "::/0"},
})
if err != nil {
defaultDialerErr = fmt.Errorf("failed to initialize default dialier: %w", err)
}
}
// Dial dials provided network and address using default dialer.
func Dial(network, address string) (net.Conn, error) {
if defaultDialerErr != nil {
return nil, defaultDialerErr
}
return defaultDialer.DialContext(context.Background(), network, address)
}

View file

@ -1,29 +0,0 @@
//go:build integration
package multinet
import (
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/require"
)
func TestDefaultDialer(t *testing.T) {
srv := startHTTP(t)
defer require.NoError(t, srv.Close())
conn, err := Dial("tcp", "localhost:8080")
require.NoError(t, err)
require.NoError(t, conn.Close())
}
func startHTTP(t *testing.T) *http.Server {
srv := &http.Server{Addr: ":8080"}
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "Test stub") })
go func() {
require.ErrorIs(t, srv.ListenAndServe(), http.ErrServerClosed)
}()
return srv
}

174
dialer.go
View file

@ -1,12 +1,10 @@
package multinet
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"sort"
"sync"
"time"
)
@ -20,12 +18,6 @@ type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// Multidialer is like Dialer, but supports link state updates.
type Multidialer interface {
Dialer
UpdateInterface(name string, addr netip.Addr, status bool)
}
var (
_ Dialer = (*net.Dialer)(nil)
_ Dialer = (*dialer)(nil)
@ -47,30 +39,29 @@ type dialer struct {
resolver net.Resolver
// See Config.FallbackDelay description.
fallbackDelay time.Duration
// Event handler.
eh EventHandler
fyrchik marked this conversation as resolved
Review

What is the usecase for this, besides tests?

What is the usecase for this, besides tests?
Review

Metrics, logging, failover

Metrics, logging, failover
}
// Subnet represents a single subnet, possibly routable from multiple interfaces.
// Subnet represents a single subnet, possibly routable from multiple source IPs.
type Subnet struct {
Mask netip.Prefix
Interfaces []Source
Prefix netip.Prefix
SourceIPs []netip.Addr
}
// Source represents a single source IP belonging to a particular subnet.
type Source struct {
Name string
LocalAddr *net.TCPAddr
Down bool
type EventHandler interface {
DialPerformed(sourceIP net.Addr, network, address string, err error)
}
// Config contains Multidialer configuration.
type Config struct {
// Routable subnets to prioritize in CIDR format.
Subnets []string
// Routable subnets.
Subnets []Subnet
// If true, the only configurd subnets available through this dialer.
// Otherwise, a failback to the net.DefaultDialer.
Restrict bool
// Dialer containes default options for the net.Dialer to use.
// LocalAddr is overriden.
// Dialer contains default options for the net.Dialer to use.
// LocalAddr is overridden.
Dialer net.Dialer
// Balancer specifies algorithm used to pick source address.
Balancer BalancerType
@ -83,47 +74,17 @@ type Config struct {
// If zero, a default delay of 300ms is used.
// A negative value disables Fast Fallback support.
FallbackDelay time.Duration
// InterfaceSource is custom `Interface`` source.
// If not specified, default implementation is used (`net.Interfaces()``).
InterfaceSource func() ([]Interface, error)
// DialContext is custom DialContext function.
// If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`).
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// EventHandler defines event handler.
EventHandler EventHandler
}
// NewDialer ...
func NewDialer(c Config) (Multidialer, error) {
var ifaces []Interface
var err error
if c.InterfaceSource != nil {
ifaces, err = c.InterfaceSource()
} else {
ifaces, err = systemInterfaces()
}
if err != nil {
return nil, err
}
sort.Slice(ifaces, func(i, j int) bool {
return ifaces[i].Name() < ifaces[j].Name()
})
var sources []iface
for i := range ifaces {
info, err := processIface(ifaces[i])
if err != nil {
return nil, err
}
sources = append(sources, info)
}
func NewDialer(c Config) (Dialer, error) {
var d dialer
for _, subnet := range c.Subnets {
s, err := processSubnet(subnet, sources)
if err != nil {
return nil, err
}
d.subnets = append(d.subnets, s)
}
d.subnets = c.Subnets
switch c.Balancer {
case BalancerTypeNoop:
@ -145,66 +106,15 @@ func NewDialer(c Config) (Multidialer, error) {
d.customDialContext = c.DialContext
}
if c.EventHandler != nil {
d.eh = c.EventHandler
} else {
d.eh = noopEventHandler{}
}
return &d, nil
}
type iface struct {
name string
addrs []netip.Prefix
down bool
}
func processIface(info Interface) (iface, error) {
ips, err := info.Addrs()
if err != nil {
return iface{}, err
}
var addrs []netip.Prefix
for i := range ips {
p, err := netip.ParsePrefix(ips[i].String())
if err != nil {
return iface{}, err
}
addrs = append(addrs, p)
}
return iface{name: info.Name(), addrs: addrs, down: info.Down()}, nil
}
func processSubnet(subnet string, sources []iface) (Subnet, error) {
s, err := netip.ParsePrefix(subnet)
if err != nil {
return Subnet{}, err
}
var ifs []Source
for _, source := range sources {
for i := range source.addrs {
src := source.addrs[i].Addr()
if s.Contains(src) {
ifs = append(ifs, Source{
Name: source.name,
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
Down: source.down,
})
}
}
}
sort.Slice(ifs, func(i, j int) bool {
if ifs[i].Name != ifs[j].Name {
return ifs[i].Name < ifs[j].Name
}
return bytes.Compare(ifs[i].LocalAddr.IP, ifs[j].LocalAddr.IP) == -1
})
return Subnet{
Mask: s,
Interfaces: ifs,
}, nil
}
// DialContext implements the Dialer interface.
// Hostnames for address are currently not supported.
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
@ -371,7 +281,7 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net
defer d.mtx.RUnlock()
for i := range d.subnets {
if d.subnets[i].Mask.Contains(addr.Addr()) {
if d.subnets[i].Prefix.Contains(addr.Addr()) {
return d.balancer.DialContext(ctx, &d.subnets[i], network, address)
}
}
@ -379,37 +289,19 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net
if d.restrict {
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
}
return d.dialContext(&d.dialer, ctx, network, address)
return d.dialContext(ctx, &d.dialer, network, address)
}
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
func (d *dialer) dialContext(ctx context.Context, nd *net.Dialer, network, address string) (net.Conn, error) {
var conn net.Conn
var err error
if h := d.customDialContext; h != nil {
return h(nd, ctx, network, address)
}
return nd.DialContext(ctx, network, address)
}
// UpdateInterface implements the Multidialer interface.
// Updating address on a specific interface is currently not supported.
func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) {
d.mtx.Lock()
defer d.mtx.Unlock()
for i := range d.subnets {
for j := range d.subnets[i].Interfaces {
matchIface := d.subnets[i].Interfaces[j].Name == iface
if matchIface {
d.subnets[i].Interfaces[j].Down = !up
continue
}
a, _ := netip.AddrFromSlice(d.subnets[i].Interfaces[j].LocalAddr.IP)
matchAddr := a.IsUnspecified() || addr == a
if matchAddr {
d.subnets[i].Interfaces[j].Down = !up
}
}
conn, err = h(nd, ctx, network, address)
} else {
conn, err = nd.DialContext(ctx, network, address)
}
d.eh.DialPerformed(nd.LocalAddr, network, address, err)
return conn, err
}
// splitByType divides an address list into two categories:
@ -429,3 +321,7 @@ func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks
}
return
}
type noopEventHandler struct{}
func (s noopEventHandler) DialPerformed(net.Addr, string, string, error) {}

View file

@ -3,6 +3,7 @@ package multinet
import (
"context"
"net"
"net/netip"
"testing"
"time"
@ -14,8 +15,15 @@ func TestHostnameResolveIPv4(t *testing.T) {
resolvedAddr := "10.11.12.180:8080"
resolved := false
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"},
InterfaceSource: testInterfacesV4,
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("10.11.12.101"),
netip.MustParseAddr("10.11.12.102"),
},
},
},
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
@ -42,8 +50,15 @@ func TestHostnameResolveIPv6(t *testing.T) {
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
resolved := false
d, err := NewDialer(Config{
Subnets: []string{"2001:db8:85a3:8d3::/64"},
InterfaceSource: testInterfacesV6,
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("2001:db8:85a3:8d3::/64"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("2001:db8:85a3:8d3:1319:8a2e:370:7348"),
netip.MustParseAddr("2001:db8:85a3:8d3:1319:8a2e:370:8192"),
},
},
},
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
@ -65,70 +80,6 @@ func TestHostnameResolveIPv6(t *testing.T) {
require.True(t, resolved)
}
func testInterfacesV4() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.101/24",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.102/24",
},
},
},
}, nil
}
func testInterfacesV6() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
},
},
},
}, nil
}
type testInterface struct {
name string
addrs []net.Addr
down bool
}
func (i *testInterface) Name() string { return i.name }
func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil }
func (i *testInterface) Down() bool { return i.down }
type testAddr struct {
network string
str string
}
func (a *testAddr) Network() string { return a.network }
func (a *testAddr) String() string { return a.str }
type testDnsConn struct {
wantName string
ipv4 []byte

View file

@ -1,166 +0,0 @@
//go:build integration
package multinet
import (
"net"
"net/netip"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
func TestDialer(t *testing.T) {
runInNewNamespace(t, "2 interfaces with multiple routes in different subnets", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.10/23", "4.4.4.4/8"},
"testdev2": {"1.2.30.11/23", "4.4.4.5/8"},
})
// Do not use `t.Run` because everything should be executed in a single OS thread.
{ // Restrict to a single subnet.
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("1.2.30.0/23"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 10}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}},
},
},
}, d.(*dialer).subnets)
}
{ // Restrict to two subnets.
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23", "4.0.0.0/8"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("1.2.30.0/23"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 10}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}},
},
},
{
Mask: netip.MustParsePrefix("4.0.0.0/8"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{4, 4, 4, 4}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{4, 4, 4, 5}}},
},
},
}, d.(*dialer).subnets)
}
})
runInNewNamespace(t, "4 interfaces, 2 for data, 2 internal", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"internal1": {"192.168.0.1/16"},
"internal2": {"192.168.0.2/16"},
"data1": {"10.11.12.101/24"},
"data2": {"10.11.12.102/24"},
})
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24", "192.168.0.0/16"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("10.11.12.0/24"),
Interfaces: []Source{
{Name: "data1", LocalAddr: &net.TCPAddr{IP: net.IP{10, 11, 12, 101}}},
{Name: "data2", LocalAddr: &net.TCPAddr{IP: net.IP{10, 11, 12, 102}}},
},
},
{
Mask: netip.MustParsePrefix("192.168.0.0/16"),
Interfaces: []Source{
{Name: "internal1", LocalAddr: &net.TCPAddr{IP: net.IP{192, 168, 0, 1}}},
{Name: "internal2", LocalAddr: &net.TCPAddr{IP: net.IP{192, 168, 0, 2}}},
},
},
}, d.(*dialer).subnets)
})
runInNewNamespace(t, "with ipv6", func(t *testing.T, ns netns.NsHandle) {
addr1 := "2001:db8:85a3:8d3:1319:8a2e:370:7348/64"
addr2 := "2001:db8:85a3:8d3:1319:8a2e:370:8192/64"
setup(t, map[string][]string{
"testdev1": {addr1},
"testdev2": {addr2},
})
// Do not use `t.Run` because everything should be executed in a single OS thread.
{ // Restrict to a single subnet.
d, err := NewDialer(Config{
Subnets: []string{"2001:db8:85a3:8d3::/64"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("2001:db8:85a3:8d3::/64"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: mustParseIPv6(t, addr1)},
{Name: "testdev2", LocalAddr: mustParseIPv6(t, addr2)},
},
},
}, d.(*dialer).subnets)
}
})
}
func mustParseIPv6(t *testing.T, s string) *net.TCPAddr {
ip, _, err := net.ParseCIDR(s)
require.NoError(t, err)
return &net.TCPAddr{IP: ip}
}
func setup(t *testing.T, config map[string][]string) {
for name, ips := range config {
link := createLink(t, name)
for i := range ips {
ip, err := netlink.ParseIPNet(ips[i])
require.NoError(t, err)
require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip}))
}
}
}
func createLink(t *testing.T, name string) netlink.Link {
require.NoError(t, netlink.LinkAdd(&netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: name}}))
link, err := netlink.LinkByName(name)
require.NoError(t, err)
require.NoError(t, netlink.LinkSetUp(link))
return link
}
func runInNewNamespace(t *testing.T, name string, f func(t *testing.T, ns netns.NsHandle)) {
t.Run(name, func(t *testing.T) {
// To avoid messing with host network settings,
// we create a new names space and execute tests in it.
// Switching thread can move us to a different namespace, thus this line.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
origns, err := netns.Get()
require.NoError(t, err)
defer origns.Close()
defer netns.Set(origns)
newns, err := netns.New()
require.NoError(t, err)
defer newns.Close()
f(t, newns)
})
}

View file

@ -2,17 +2,20 @@ package multinet
import (
"context"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
func TestInterfacesDown(t *testing.T) {
func TestNoSourceIPs(t *testing.T) {
t.Run("noop balancer", func(t *testing.T) {
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"},
InterfaceSource: testDownInterfaces,
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
},
},
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "10.11.12.254:8080")
@ -21,9 +24,12 @@ func TestInterfacesDown(t *testing.T) {
})
t.Run("round robin balancer", func(t *testing.T) {
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24"},
InterfaceSource: testDownInterfaces,
Balancer: BalancerTypeRoundRobin,
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
},
},
Balancer: BalancerTypeRoundRobin,
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "10.11.12.254:8080")
@ -31,28 +37,3 @@ func TestInterfacesDown(t *testing.T) {
require.Nil(t, conn)
})
}
func testDownInterfaces() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.101/24",
},
},
down: true,
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.102/24",
},
},
down: true,
},
}, nil
}

9
go.mod
View file

@ -1,13 +1,10 @@
module git.frostfs.info/TrueCloudLab/multinet
go 1.20
go 1.22
require (
github.com/stretchr/testify v1.8.4
github.com/vishvananda/netlink v1.1.0
github.com/vishvananda/netns v0.0.4
golang.org/x/net v0.17.0
golang.org/x/sys v0.13.0
github.com/stretchr/testify v1.9.0
golang.org/x/net v0.26.0
)
require (

BIN
go.sum

Binary file not shown.

View file

@ -1,67 +0,0 @@
package multinet
import (
"net/netip"
"sync"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
type NetlinkWatcher struct {
d Multidialer
linkUpdates chan netlink.LinkUpdate
addrUpdates chan netlink.AddrUpdate
done chan struct{}
wg sync.WaitGroup
}
func NewNetlinkWatcher(d Multidialer) *NetlinkWatcher {
return &NetlinkWatcher{
d: d,
addrUpdates: make(chan netlink.AddrUpdate, 1),
linkUpdates: make(chan netlink.LinkUpdate, 1),
done: make(chan struct{}),
}
}
func (w *NetlinkWatcher) Start() error {
if err := netlink.LinkSubscribe(w.linkUpdates, w.done); err != nil {
return err
}
if err := netlink.AddrSubscribe(w.addrUpdates, w.done); err != nil {
close(w.done)
return err
}
w.wg.Add(1)
go w.watch()
return nil
}
func (w *NetlinkWatcher) watch() {
defer w.wg.Done()
for {
select {
case <-w.done:
return
case update := <-w.addrUpdates:
// Wont work if an multiple interfaces share IP address.
// Should not happen in practice.
ip, ok := netip.AddrFromSlice(update.LinkAddress.IP)
if !ok {
continue
}
w.d.UpdateInterface("", ip, update.NewAddr)
case update := <-w.linkUpdates:
up := update.Flags&unix.IFF_UP != 0
w.d.UpdateInterface(update.Link.Attrs().Name, netip.Addr{}, up)
}
}
}
func (w *NetlinkWatcher) Stop() {
close(w.done)
w.wg.Wait()
}

View file

@ -1,157 +0,0 @@
//go:build integration
package multinet
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
func Test_NetlinkWatcher(t *testing.T) {
runInNewNamespace(t, "noop balancer, disable interface", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
require.NoError(t, netlink.LinkSetDown(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.LinkSetUp(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
})
runInNewNamespace(t, "noop balancer, remove address", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
ip, err := netlink.ParseIPNet("1.2.30.11/23")
require.NoError(t, err)
require.NoError(t, netlink.AddrDel(link, &netlink.Addr{IPNet: ip}))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip}))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
})
runInNewNamespace(t, "round-robin balancer, disable interface", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
Balancer: BalancerTypeRoundRobin,
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
require.NoError(t, netlink.LinkSetDown(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.LinkSetUp(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)
})
}
func checkDialAddr(t *testing.T, d Multidialer, ch chan net.Addr, expected net.Addr) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err := d.DialContext(ctx, "tcp", "1.2.30.42:12345")
require.NoError(t, err)
select {
case addr := <-ch:
require.Equal(t, expected, addr)
default:
require.Fail(t, "DialContext() was not called")
}
}

View file

@ -1,38 +0,0 @@
package multinet
import "net"
// Interface provides information about net.Interface.
type Interface interface {
Name() string
Addrs() ([]net.Addr, error)
Down() bool
}
type netInterface struct {
iface net.Interface
}
func (i *netInterface) Name() string {
return i.iface.Name
}
func (i *netInterface) Addrs() ([]net.Addr, error) {
return i.iface.Addrs()
}
func (i *netInterface) Down() bool {
return i.iface.Flags&net.FlagUp == 0
}
func systemInterfaces() ([]Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var result []Interface
for _, iface := range ifaces {
result = append(result, &netInterface{iface: iface})
}
return result, nil
}

View file

@ -1,7 +0,0 @@
@@
@@
+import "git.frostfs.info/TrueCloudLab/multinet"
-import "net"
-net.Dial(...)
+multinet.Dial(...)

19
testdata/patch_0.go vendored
View file

@ -1,19 +0,0 @@
package main
import (
"log"
"net"
)
const addr = "s01.frostfs.devenv:8080"
func main() {
_, err := net.Dial(getNetwork(), addr)
if err != nil {
log.Fatal(err)
}
}
func getNetwork() string {
return "tcp"
}

14
testdata/patch_1.go vendored
View file

@ -1,14 +0,0 @@
package main
import (
"log"
"net"
)
func main() {
ip := net.IPv4(192, 168, 0, 10)
_, err := net.Dial("tcp", ip.String()+":8080")
if err != nil {
log.Fatal(err)
}
}