Compare commits

..

No commits in common. "master" and "master" have entirely different histories.

20 changed files with 681 additions and 548 deletions

View file

@ -1,21 +0,0 @@
name: DCO action
on: [pull_request]
jobs:
dco:
name: DCO
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
- name: Run commit format checker
uses: https://git.frostfs.info/TrueCloudLab/dco-go@v3
with:
from: 'origin/${{ github.event.pull_request.base.ref }}'

View file

@ -1,25 +0,0 @@
name: Pre-commit hooks
on: [pull_request]
jobs:
precommit:
name: Pre-commit
env:
# Skip pre-commit hooks which are executed by other actions.
SKIP: make-lint,go-staticcheck-repo-mod,go-unit-tests,gofumpt
runs-on: ubuntu-22.04
# If we use actions/setup-python from either Github or Gitea,
# the line above fails with a cryptic error about not being able to find python.
# So install everything manually.
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.23
- name: Set up Python
run: |
apt update
apt install -y pre-commit
- name: Run pre-commit
run: pre-commit run --color=always --hook-stage manual --all-files

View file

@ -1,111 +0,0 @@
name: Tests and linters
on: [pull_request]
jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Install linters
run: make lint-install
- name: Run linters
run: make lint
tests:
name: Tests
runs-on: ubuntu-latest
strategy:
matrix:
go_versions: [ '1.22', '1.23' ]
fail-fast: false
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '${{ matrix.go_versions }}'
cache: true
- name: Run tests
run: make test
tests-race:
name: Tests with -race
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Run tests
run: go test ./... -count=1 -race
staticcheck:
name: Staticcheck
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Install staticcheck
run: make staticcheck-install
- name: Run staticcheck
run: make staticcheck-run
gopls:
name: gopls check
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Install gopls
run: make gopls-install
- name: Run gopls
run: make gopls-run
fumpt:
name: Run gofumpt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
cache: true
- name: Install gofumpt
run: make fumpt-install
- name: Run gofumpt
run: |
make fumpt
git diff --exit-code --quiet

View file

@ -1,22 +0,0 @@
name: Vulncheck
on: [pull_request]
jobs:
vulncheck:
name: Vulncheck
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v3
with:
go-version: '1.23'
- name: Install govulncheck
run: go install golang.org/x/vuln/cmd/govulncheck@latest
- name: Run govulncheck
run: govulncheck ./...

1
.gitattributes vendored
View file

@ -1 +0,0 @@
/go.sum -diff

22
.gitignore vendored
View file

@ -1,22 +0,0 @@
# IDE
.idea
.vscode
# Vendoring
vendor
# tempfiles
.DS_Store
*~
.cache
temp
tmp
# binary
bin/
release/
# coverage
coverage.txt
coverage.html

View file

@ -1,75 +0,0 @@
# This file contains all available configuration options
# with their default values.
# options for analysis running
run:
# timeout for analysis, e.g. 30s, 5m, default is 1m
timeout: 20m
# include test files or not, default is true
tests: false
# output configuration options
output:
# colored-line-number|line-number|json|tab|checkstyle|code-climate, default is "colored-line-number"
formats:
- format: tab
# all available settings of specific linters
linters-settings:
exhaustive:
# indicates that switch statements are to be considered exhaustive if a
# 'default' case is present, even if all enum members aren't listed in the
# switch
default-signifies-exhaustive: true
govet:
# report about shadowed variables
check-shadowing: false
staticcheck:
checks: ["all", "-SA1019"] # TODO Enable SA1019 after deprecated warning are fixed.
funlen:
lines: 80 # default 60
statements: 60 # default 40
gocognit:
min-complexity: 40 # default 30
unused:
field-writes-are-uses: false
exported-fields-are-used: false
local-variables-are-used: false
linters:
enable:
# mandatory linters
- govet
- revive
# some default golangci-lint linters
- errcheck
- gosimple
- godot
- ineffassign
- staticcheck
- typecheck
- unused
# extra linters
- bidichk
- durationcheck
- exhaustive
- copyloopvar
- gofmt
- goimports
- misspell
- predeclared
- reassign
- whitespace
- containedctx
- funlen
- gocognit
- contextcheck
- importas
- perfsprint
- testifylint
- protogetter
disable-all: true
fast: false

View file

@ -1,56 +0,0 @@
ci:
autofix_prs: false
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-shebang-scripts-are-executable
- id: check-merge-conflict
- id: check-json
- id: check-xml
- id: check-yaml
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: end-of-file-fixer
exclude: "(.key|.svg)$"
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.9.0.6
hooks:
- id: shellcheck
- repo: local
hooks:
- id: make-lint
name: Run Make Lint
entry: make lint
language: system
pass_filenames: false
- repo: local
hooks:
- id: go-unit-tests
name: go unit tests
entry: make test GOFLAGS=''
pass_filenames: false
types: [go]
language: system
- repo: local
hooks:
- id: gofumpt
name: gofumpt
entry: make fumpt
pass_filenames: false
types: [go]
language: system
- repo: https://github.com/TekWizely/pre-commit-golang
rev: v1.0.0-rc.1
hooks:
- id: go-staticcheck-repo-mod
- id: go-mod-tidy

82
Makefile Executable file → Normal file
View file

@ -1,82 +1,6 @@
#!/usr/bin/make -f
STATICCHECK_VERSION ?= 2024.1.1
LINT_VERSION ?= 1.60.3
BIN = bin
OUTPUT_LINT_DIR ?= $(abspath $(BIN))/linters
LINT_DIR = $(OUTPUT_LINT_DIR)/golangci-lint-$(LINT_VERSION)
TMP_DIR := .cache
STATICCHECK_DIR ?= $(abspath $(BIN))/staticcheck
STATICCHECK_VERSION_DIR ?= $(STATICCHECK_DIR)/$(STATICCHECK_VERSION)
GOFUMPT_VERSION ?= v0.7.0
GOFUMPT_DIR ?= $(abspath $(BIN))/gofumpt
GOFUMPT_VERSION_DIR ?= $(GOFUMPT_DIR)/$(GOFUMPT_VERSION)
GOPLS_VERSION ?= v0.16.2
GOPLS_DIR ?= $(abspath $(BIN))/gopls
GOPLS_VERSION_DIR ?= $(GOPLS_DIR)/$(GOPLS_VERSION)
GOPLS_TEMP_FILE := $(shell mktemp)
integration-test:
# TODO figure out needed capabilities
sudo go test -count=1 -v ./... -tags=integration
test:
go test -count=1 -v ./...
# Install linters
lint-install:
@rm -rf $(OUTPUT_LINT_DIR)
@mkdir -p $(OUTPUT_LINT_DIR)
@CGO_ENABLED=1 GOBIN=$(LINT_DIR) go install -trimpath github.com/golangci/golangci-lint/cmd/golangci-lint@v$(LINT_VERSION)
# Run linters
lint:
@if [ ! -d "$(LINT_DIR)" ]; then \
make lint-install; \
fi
$(LINT_DIR)/golangci-lint run
# Install staticcheck
staticcheck-install:
@rm -rf $(STATICCHECK_DIR)
@mkdir -p $(STATICCHECK_DIR)
@GOBIN=$(STATICCHECK_VERSION_DIR) go install honnef.co/go/tools/cmd/staticcheck@$(STATICCHECK_VERSION)
# Run staticcheck
staticcheck-run:
@if [ ! -d "$(STATICCHECK_VERSION_DIR)" ]; then \
make staticcheck-install; \
fi
@$(STATICCHECK_VERSION_DIR)/staticcheck ./...
# Install gopls
gopls-install:
@rm -rf $(GOPLS_DIR)
@mkdir -p $(GOPLS_DIR)
@GOBIN=$(GOPLS_VERSION_DIR) go install golang.org/x/tools/gopls@$(GOPLS_VERSION)
# Run gopls
gopls-run:
@if [ ! -d "$(GOPLS_VERSION_DIR)" ]; then \
make gopls-install; \
fi
$(GOPLS_VERSION_DIR)/gopls check $(SOURCES) 2>&1 >$(GOPLS_TEMP_FILE)
@if [[ $$(wc -l < $(GOPLS_TEMP_FILE)) -ne 0 ]]; then \
cat $(GOPLS_TEMP_FILE); \
exit 1; \
fi
rm $(GOPLS_TEMP_FILE)
# Install gofumpt
fumpt-install:
@rm -rf $(GOFUMPT_DIR)
@mkdir -p $(GOFUMPT_DIR)
@GOBIN=$(GOFUMPT_VERSION_DIR) go install mvdan.cc/gofumpt@$(GOFUMPT_VERSION)
# Run gofumpt
fumpt:
@if [ ! -d "$(GOFUMPT_VERSION_DIR)" ]; then \
make fumpt-install; \
fi
@echo "⇒ Processing gofumpt check"
$(GOFUMPT_VERSION_DIR)/gofumpt -l -w .

View file

@ -16,31 +16,10 @@ But sometimes you need to invent a bicycle.
## Usage
```golang
import (
"context"
"net"
"net/netip"
"git.frostfs.info/TrueCloudLab/multinet"
)
import "git.frostfs.info/TrueCloudLab/multinet"
d, err := multinet.NewDialer(Config{
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.70.0/23"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("10.11.70.42"),
netip.MustParseAddr("10.11.71.42"),
},
},
{
Prefix: netip.MustParsePrefix("192.168.123.0/24"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("192.168.123.42"),
netip.MustParseAddr("192.168.123.142"),
},
},
},
Subnets: []string{"10.11.70.0/23", "192.168.123.0/24"},
Balancer: multinet.BalancerTypeRoundRobin,
})
if err != nil {
@ -53,3 +32,10 @@ if err != nil {
}
// do stuff
```
### Updating interface state
`Multidialer` exposes `UpdateInterface()` method for updating state of a single link.
`NetlinkWatcher` can wrap `Multidialer` type and perform all updates automatically.
TODO: describe needed capabilities here.

View file

@ -3,7 +3,6 @@ package multinet
import (
"context"
"errors"
"fmt"
"net"
"sync/atomic"
)
@ -19,8 +18,6 @@ const (
BalancerTypeRoundRobin BalancerType = "roundrobin"
)
var errNoSuitableNodeFound = errors.New("no suitale node found")
type balancer interface {
DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error)
}
@ -32,13 +29,17 @@ type roundRobin struct {
func (r *roundRobin) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) {
next := int(r.i.Add(1))
for i := range s.SourceIPs {
ii := s.SourceIPs[(i+next)%len(s.SourceIPs)]
for i := range s.Interfaces {
ii := s.Interfaces[(i+next)%len(s.Interfaces)]
if ii.Down {
continue
}
dd := r.d.dialer
dd.LocalAddr = &net.TCPAddr{IP: net.IP(ii.AsSlice())}
return r.d.dialContext(ctx, &dd, network, address)
dd.LocalAddr = ii.LocalAddr
return r.d.dialContext(&dd, ctx, network, address)
}
return nil, fmt.Errorf("(*roundRobin).DialContext: %w", errNoSuitableNodeFound)
return nil, errors.New("(*roundRobin).DialContext: no suitale node found")
}
type firstEnabled struct {
@ -46,11 +47,15 @@ type firstEnabled struct {
}
func (r *firstEnabled) DialContext(ctx context.Context, s *Subnet, network, address string) (net.Conn, error) {
for i := range s.SourceIPs {
ii := s.SourceIPs[i]
for i := range s.Interfaces {
ii := s.Interfaces[i%len(s.Interfaces)]
if ii.Down {
continue
}
dd := r.d.dialer
dd.LocalAddr = &net.TCPAddr{IP: net.IP(ii.AsSlice())}
return r.d.dialContext(ctx, &dd, network, address)
dd.LocalAddr = ii.LocalAddr
return r.d.dialContext(&dd, ctx, network, address)
}
return nil, fmt.Errorf("(*firstEnabled).DialContext: %w", errNoSuitableNodeFound)
return nil, errors.New("(*firstEnabled).DialContext: no suitale node found")
}

173
dialer.go
View file

@ -1,10 +1,12 @@
package multinet
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"sort"
"sync"
"time"
)
@ -18,6 +20,12 @@ type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// Multidialer is like Dialer, but supports link state updates.
type Multidialer interface {
Dialer
UpdateInterface(name string, addr netip.Addr, status bool) error
}
var (
_ Dialer = (*net.Dialer)(nil)
_ Dialer = (*dialer)(nil)
@ -39,29 +47,30 @@ type dialer struct {
resolver net.Resolver
// See Config.FallbackDelay description.
fallbackDelay time.Duration
// Event handler.
eh EventHandler
}
// Subnet represents a single subnet, possibly routable from multiple source IPs.
// Subnet represents a single subnet, possibly routable from multiple interfaces.
type Subnet struct {
Prefix netip.Prefix
SourceIPs []netip.Addr
Mask netip.Prefix
Interfaces []Source
}
type EventHandler interface {
DialPerformed(sourceIP net.Addr, network, address string, err error)
// Source represents a single source IP belonging to a particular subnet.
type Source struct {
Name string
LocalAddr *net.TCPAddr
Down bool
}
// Config contains Multidialer configuration.
type Config struct {
// Routable subnets.
Subnets []Subnet
// Routable subnets to prioritize in CIDR format.
Subnets []string
// If true, the only configurd subnets available through this dialer.
// Otherwise, a failback to the net.DefaultDialer.
Restrict bool
// Dialer contains default options for the net.Dialer to use.
// LocalAddr is overridden.
// Dialer containes default options for the net.Dialer to use.
// LocalAddr is overriden.
Dialer net.Dialer
// Balancer specifies algorithm used to pick source address.
Balancer BalancerType
@ -74,17 +83,47 @@ type Config struct {
// If zero, a default delay of 300ms is used.
// A negative value disables Fast Fallback support.
FallbackDelay time.Duration
// InterfaceSource is custom `Interface`` source.
// If not specified, default implementation is used (`net.Interfaces()``).
InterfaceSource func() ([]Interface, error)
// DialContext is custom DialContext function.
// If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`).
DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error)
// EventHandler defines event handler.
EventHandler EventHandler
}
// NewDialer ...
func NewDialer(c Config) (Dialer, error) {
func NewDialer(c Config) (Multidialer, error) {
var ifaces []Interface
var err error
if c.InterfaceSource != nil {
ifaces, err = c.InterfaceSource()
} else {
ifaces, err = systemInterfaces()
}
if err != nil {
return nil, err
}
sort.Slice(ifaces, func(i, j int) bool {
return ifaces[i].Name() < ifaces[j].Name()
})
var sources []iface
for i := range ifaces {
info, err := processIface(ifaces[i])
if err != nil {
return nil, err
}
sources = append(sources, info)
}
var d dialer
d.subnets = c.Subnets
for _, subnet := range c.Subnets {
s, err := processSubnet(subnet, sources)
if err != nil {
return nil, err
}
d.subnets = append(d.subnets, s)
}
switch c.Balancer {
case BalancerTypeNoop:
@ -106,20 +145,69 @@ func NewDialer(c Config) (Dialer, error) {
d.customDialContext = c.DialContext
}
if c.EventHandler != nil {
d.eh = c.EventHandler
} else {
d.eh = noopEventHandler{}
return &d, nil
}
type iface struct {
name string
addrs []netip.Prefix
}
func processIface(info Interface) (iface, error) {
ips, err := info.Addrs()
if err != nil {
return iface{}, err
}
return &d, nil
var addrs []netip.Prefix
for i := range ips {
p, err := netip.ParsePrefix(ips[i].String())
if err != nil {
return iface{}, err
}
addrs = append(addrs, p)
}
return iface{name: info.Name(), addrs: addrs}, nil
}
func processSubnet(subnet string, sources []iface) (Subnet, error) {
s, err := netip.ParsePrefix(subnet)
if err != nil {
return Subnet{}, err
}
var ifs []Source
for _, source := range sources {
for i := range source.addrs {
src := source.addrs[i].Addr()
if s.Contains(src) {
ifs = append(ifs, Source{
Name: source.name,
LocalAddr: &net.TCPAddr{IP: net.IP(src.AsSlice())},
})
}
}
}
sort.Slice(ifs, func(i, j int) bool {
if ifs[i].Name != ifs[j].Name {
return ifs[i].Name < ifs[j].Name
}
return bytes.Compare(ifs[i].LocalAddr.IP, ifs[j].LocalAddr.IP) == -1
})
return Subnet{
Mask: s,
Interfaces: ifs,
}, nil
}
// DialContext implements the Dialer interface.
// Hostnames for address are currently not supported.
func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
addr, err := netip.ParseAddrPort(address)
if err != nil { // try resolve as hostname
if err != nil { //try resolve as hostname
return d.dialContextHostname(ctx, network, address)
}
return d.dialAddr(ctx, network, address, addr)
@ -281,7 +369,7 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net
defer d.mtx.RUnlock()
for i := range d.subnets {
if d.subnets[i].Prefix.Contains(addr.Addr()) {
if d.subnets[i].Mask.Contains(addr.Addr()) {
return d.balancer.DialContext(ctx, &d.subnets[i], network, address)
}
}
@ -289,19 +377,38 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net
if d.restrict {
return nil, fmt.Errorf("no suitable interface for: [%s]%s", network, address)
}
return d.dialContext(ctx, &d.dialer, network, address)
return d.dialContext(&d.dialer, ctx, network, address)
}
func (d *dialer) dialContext(ctx context.Context, nd *net.Dialer, network, address string) (net.Conn, error) {
var conn net.Conn
var err error
func (d *dialer) dialContext(nd *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if h := d.customDialContext; h != nil {
conn, err = h(nd, ctx, network, address)
} else {
conn, err = nd.DialContext(ctx, network, address)
return h(nd, ctx, network, address)
}
d.eh.DialPerformed(nd.LocalAddr, network, address, err)
return conn, err
return nd.DialContext(ctx, network, address)
}
// UpdateInterface implements the Multidialer interface.
// Updating address on a specific interface is currently not supported.
func (d *dialer) UpdateInterface(iface string, addr netip.Addr, up bool) error {
d.mtx.Lock()
defer d.mtx.Unlock()
for i := range d.subnets {
for j := range d.subnets[i].Interfaces {
matchIface := d.subnets[i].Interfaces[j].Name == iface
if matchIface {
d.subnets[i].Interfaces[j].Down = !up
continue
}
a, _ := netip.AddrFromSlice(d.subnets[i].Interfaces[j].LocalAddr.IP)
matchAddr := a.IsUnspecified() || addr == a
if matchAddr {
d.subnets[i].Interfaces[j].Down = !up
}
}
}
return nil
}
// splitByType divides an address list into two categories:
@ -321,7 +428,3 @@ func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks
}
return
}
type noopEventHandler struct{}
func (s noopEventHandler) DialPerformed(net.Addr, string, string, error) {}

View file

@ -3,7 +3,6 @@ package multinet
import (
"context"
"net"
"net/netip"
"testing"
"time"
@ -15,15 +14,8 @@ func TestHostnameResolveIPv4(t *testing.T) {
resolvedAddr := "10.11.12.180:8080"
resolved := false
d, err := NewDialer(Config{
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("10.11.12.101"),
netip.MustParseAddr("10.11.12.102"),
},
},
},
Subnets: []string{"10.11.12.0/24"},
InterfaceSource: testInterfacesV4,
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
@ -50,15 +42,8 @@ func TestHostnameResolveIPv6(t *testing.T) {
ipv6 := net.ParseIP("2001:db8:85a3:8d3:1319:8a2e:370:8195")
resolved := false
d, err := NewDialer(Config{
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("2001:db8:85a3:8d3::/64"),
SourceIPs: []netip.Addr{
netip.MustParseAddr("2001:db8:85a3:8d3:1319:8a2e:370:7348"),
netip.MustParseAddr("2001:db8:85a3:8d3:1319:8a2e:370:8192"),
},
},
},
Subnets: []string{"2001:db8:85a3:8d3::/64"},
InterfaceSource: testInterfacesV6,
DialContext: func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) {
if resolvedAddr == address {
resolved = true
@ -80,6 +65,68 @@ func TestHostnameResolveIPv6(t *testing.T) {
require.True(t, resolved)
}
func testInterfacesV4() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.101/24",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "10.11.12.102/24",
},
},
},
}, nil
}
func testInterfacesV6() ([]Interface, error) {
return []Interface{
&testInterface{
name: "data1",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:7348/64",
},
},
},
&testInterface{
name: "data2",
addrs: []net.Addr{
&testAddr{
network: "tcp",
str: "2001:db8:85a3:8d3:1319:8a2e:370:8192/64",
},
},
},
}, nil
}
type testInterface struct {
name string
addrs []net.Addr
}
func (i *testInterface) Name() string { return i.name }
func (i *testInterface) Addrs() ([]net.Addr, error) { return i.addrs, nil }
type testAddr struct {
network string
str string
}
func (a *testAddr) Network() string { return a.network }
func (a *testAddr) String() string { return a.str }
type testDnsConn struct {
wantName string
ipv4 []byte

166
dialer_integration_test.go Normal file
View file

@ -0,0 +1,166 @@
//go:build integration
package multinet
import (
"net"
"net/netip"
"runtime"
"testing"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
func TestDialer(t *testing.T) {
runInNewNamespace(t, "2 interfaces with multiple routes in different subnets", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.10/23", "4.4.4.4/8"},
"testdev2": {"1.2.30.11/23", "4.4.4.5/8"},
})
// Do not use `t.Run` because everything should be executed in a single OS thread.
{ // Restrict to a single subnet.
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("1.2.30.0/23"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 10}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}},
},
},
}, d.(*dialer).subnets)
}
{ // Restrict to two subnets.
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23", "4.0.0.0/8"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("1.2.30.0/23"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 10}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}},
},
},
{
Mask: netip.MustParsePrefix("4.0.0.0/8"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: &net.TCPAddr{IP: net.IP{4, 4, 4, 4}}},
{Name: "testdev2", LocalAddr: &net.TCPAddr{IP: net.IP{4, 4, 4, 5}}},
},
},
}, d.(*dialer).subnets)
}
})
runInNewNamespace(t, "4 interfaces, 2 for data, 2 internal", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"internal1": {"192.168.0.1/16"},
"internal2": {"192.168.0.2/16"},
"data1": {"10.11.12.101/24"},
"data2": {"10.11.12.102/24"},
})
d, err := NewDialer(Config{
Subnets: []string{"10.11.12.0/24", "192.168.0.0/16"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("10.11.12.0/24"),
Interfaces: []Source{
{Name: "data1", LocalAddr: &net.TCPAddr{IP: net.IP{10, 11, 12, 101}}},
{Name: "data2", LocalAddr: &net.TCPAddr{IP: net.IP{10, 11, 12, 102}}},
},
},
{
Mask: netip.MustParsePrefix("192.168.0.0/16"),
Interfaces: []Source{
{Name: "internal1", LocalAddr: &net.TCPAddr{IP: net.IP{192, 168, 0, 1}}},
{Name: "internal2", LocalAddr: &net.TCPAddr{IP: net.IP{192, 168, 0, 2}}},
},
},
}, d.(*dialer).subnets)
})
runInNewNamespace(t, "with ipv6", func(t *testing.T, ns netns.NsHandle) {
addr1 := "2001:db8:85a3:8d3:1319:8a2e:370:7348/64"
addr2 := "2001:db8:85a3:8d3:1319:8a2e:370:8192/64"
setup(t, map[string][]string{
"testdev1": {addr1},
"testdev2": {addr2},
})
// Do not use `t.Run` because everything should be executed in a single OS thread.
{ // Restrict to a single subnet.
d, err := NewDialer(Config{
Subnets: []string{"2001:db8:85a3:8d3::/64"},
})
require.NoError(t, err)
require.Equal(t, []Subnet{
{
Mask: netip.MustParsePrefix("2001:db8:85a3:8d3::/64"),
Interfaces: []Source{
{Name: "testdev1", LocalAddr: mustParseIPv6(t, addr1)},
{Name: "testdev2", LocalAddr: mustParseIPv6(t, addr2)},
},
},
}, d.(*dialer).subnets)
}
})
}
func mustParseIPv6(t *testing.T, s string) *net.TCPAddr {
ip, _, err := net.ParseCIDR(s)
require.NoError(t, err)
return &net.TCPAddr{IP: ip}
}
func setup(t *testing.T, config map[string][]string) {
for name, ips := range config {
link := createLink(t, name)
for i := range ips {
ip, err := netlink.ParseIPNet(ips[i])
require.NoError(t, err)
require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip}))
}
}
}
func createLink(t *testing.T, name string) netlink.Link {
require.NoError(t, netlink.LinkAdd(&netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: name}}))
link, err := netlink.LinkByName(name)
require.NoError(t, err)
require.NoError(t, netlink.LinkSetUp(link))
return link
}
func runInNewNamespace(t *testing.T, name string, f func(t *testing.T, ns netns.NsHandle)) {
t.Run(name, func(t *testing.T) {
// To avoid messing with host network settings,
// we create a new names space and execute tests in it.
// Switching thread can move us to a different namespace, thus this line.
runtime.LockOSThread()
defer runtime.UnlockOSThread()
origns, err := netns.Get()
require.NoError(t, err)
defer origns.Close()
defer netns.Set(origns)
newns, err := netns.New()
require.NoError(t, err)
defer newns.Close()
f(t, newns)
})
}

View file

@ -1,39 +0,0 @@
package multinet
import (
"context"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
func TestNoSourceIPs(t *testing.T) {
t.Run("noop balancer", func(t *testing.T) {
d, err := NewDialer(Config{
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
},
},
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "10.11.12.254:8080")
require.ErrorIs(t, err, errNoSuitableNodeFound)
require.Nil(t, conn)
})
t.Run("round robin balancer", func(t *testing.T) {
d, err := NewDialer(Config{
Subnets: []Subnet{
{
Prefix: netip.MustParsePrefix("10.11.12.0/24"),
},
},
Balancer: BalancerTypeRoundRobin,
})
require.NoError(t, err)
conn, err := d.DialContext(context.Background(), "tcp", "10.11.12.254:8080")
require.ErrorIs(t, err, errNoSuitableNodeFound)
require.Nil(t, conn)
})
}

9
go.mod
View file

@ -1,10 +1,13 @@
module git.frostfs.info/TrueCloudLab/multinet
go 1.22
go 1.21.0
require (
github.com/stretchr/testify v1.9.0
golang.org/x/net v0.26.0
github.com/stretchr/testify v1.8.4
github.com/vishvananda/netlink v1.1.0
github.com/vishvananda/netns v0.0.4
golang.org/x/net v0.17.0
golang.org/x/sys v0.13.0
)
require (

16
go.sum
View file

@ -2,10 +2,18 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/vishvananda/netlink v1.1.0 h1:1iyaYNBLmP6L0220aDnYQpo1QEV4t4hJ+xEEhhJH8j0=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

73
health.go Normal file
View file

@ -0,0 +1,73 @@
package multinet
import (
"net/netip"
"sync"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
"golang.org/x/sys/unix"
)
type NetlinkWatcher struct {
d Multidialer
linkUpdates chan netlink.LinkUpdate
addrUpdates chan netlink.AddrUpdate
done chan struct{}
wg sync.WaitGroup
}
func NewNetlinkWatcher(d Multidialer) *NetlinkWatcher {
return &NetlinkWatcher{
d: d,
addrUpdates: make(chan netlink.AddrUpdate, 1),
linkUpdates: make(chan netlink.LinkUpdate, 1),
done: make(chan struct{}),
}
}
func (w *NetlinkWatcher) Start() error {
ns, err := netns.Get()
if err != nil {
return err
}
if err := netlink.LinkSubscribe(w.linkUpdates, w.done); err != nil {
return err
}
if err := netlink.AddrSubscribe(w.addrUpdates, w.done); err != nil {
close(w.done)
return err
}
w.wg.Add(1)
go w.watch(ns)
return nil
}
func (w *NetlinkWatcher) watch(ns netns.NsHandle) {
defer w.wg.Done()
for {
select {
case <-w.done:
return
case update := <-w.addrUpdates:
// Wont work if an multiple interfaces share IP address.
// Should not happen in practice.
ip, ok := netip.AddrFromSlice(update.LinkAddress.IP)
if !ok {
continue
}
w.d.UpdateInterface("", ip, update.NewAddr)
case update := <-w.linkUpdates:
up := update.Flags&unix.IFF_UP != 0
w.d.UpdateInterface(update.Link.Attrs().Name, netip.Addr{}, up)
}
}
}
func (w *NetlinkWatcher) Stop() {
close(w.done)
w.wg.Wait()
}

157
health_integration_test.go Normal file
View file

@ -0,0 +1,157 @@
//go:build integration
package multinet
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netns"
)
func Test_NetlinkWatcher(t *testing.T) {
runInNewNamespace(t, "noop balancer, disable interface", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
require.NoError(t, netlink.LinkSetDown(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.LinkSetUp(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
})
runInNewNamespace(t, "noop balancer, remove address", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr1)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
ip, err := netlink.ParseIPNet("1.2.30.11/23")
require.NoError(t, err)
require.NoError(t, netlink.AddrDel(link, &netlink.Addr{IPNet: ip}))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.AddrAdd(link, &netlink.Addr{IPNet: ip}))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
})
runInNewNamespace(t, "round-robin balancer, disable interface", func(t *testing.T, ns netns.NsHandle) {
setup(t, map[string][]string{
"testdev1": {"1.2.30.11/23"},
"testdev2": {"1.2.30.12/23"},
})
addr1 := &net.TCPAddr{IP: net.IP{1, 2, 30, 11}}
addr2 := &net.TCPAddr{IP: net.IP{1, 2, 30, 12}}
result := make(chan net.Addr, 1)
d, err := NewDialer(Config{
Subnets: []string{"1.2.30.0/23"},
Balancer: BalancerTypeRoundRobin,
DialContext: func(d *net.Dialer, _ context.Context, _, _ string) (net.Conn, error) {
result <- d.LocalAddr
return nil, nil
},
})
require.NoError(t, err)
w := NewNetlinkWatcher(d)
require.NoError(t, w.Start())
t.Cleanup(w.Stop)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)
link, err := netlink.LinkByName("testdev1")
require.NoError(t, err)
require.NoError(t, netlink.LinkSetDown(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr2)
checkDialAddr(t, d, result, addr2)
require.NoError(t, netlink.LinkSetUp(link))
time.Sleep(time.Second)
checkDialAddr(t, d, result, addr1)
checkDialAddr(t, d, result, addr2)
})
}
func checkDialAddr(t *testing.T, d Multidialer, ch chan net.Addr, expected net.Addr) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
_, err := d.DialContext(ctx, "tcp", "1.2.30.42:12345")
require.NoError(t, err)
select {
case addr := <-ch:
require.Equal(t, expected, addr)
default:
require.Fail(t, "DialContext() was not called")
}
}

33
interface.go Normal file
View file

@ -0,0 +1,33 @@
package multinet
import "net"
// Interface provides information about net.Interface.
type Interface interface {
Name() string
Addrs() ([]net.Addr, error)
}
type netInterface struct {
iface net.Interface
}
func (i *netInterface) Name() string {
return i.iface.Name
}
func (i *netInterface) Addrs() ([]net.Addr, error) {
return i.iface.Addrs()
}
func systemInterfaces() ([]Interface, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var result []Interface
for _, iface := range ifaces {
result = append(result, &netInterface{iface: iface})
}
return result, nil
}