diff --git a/dialer.go b/dialer.go index f027d08..d0efaff 100644 --- a/dialer.go +++ b/dialer.go @@ -39,6 +39,8 @@ type dialer struct { resolver net.Resolver // See Config.FallbackDelay description. fallbackDelay time.Duration + // Event handler. + eh EventHandler } // Subnet represents a single subnet, possibly routable from multiple source IPs. @@ -47,6 +49,10 @@ type Subnet struct { SourceIPs []netip.Addr } +type EventHandler interface { + DialPerformed(sourceIP net.Addr, network, address string, err error) +} + // Config contains Multidialer configuration. type Config struct { // Routable subnets. @@ -71,6 +77,8 @@ type Config struct { // DialContext is custom DialContext function. // If not specified, default implemenattion is used (`d.DialContext(ctx, network, address)`). DialContext func(d *net.Dialer, ctx context.Context, network, address string) (net.Conn, error) + // EventHandler defines event handler. + EventHandler EventHandler } // NewDialer ... @@ -98,6 +106,12 @@ func NewDialer(c Config) (Dialer, error) { d.customDialContext = c.DialContext } + if c.EventHandler != nil { + d.eh = c.EventHandler + } else { + d.eh = noopEventHandler{} + } + return &d, nil } @@ -279,10 +293,15 @@ func (d *dialer) dialAddr(ctx context.Context, network, address string, addr net } func (d *dialer) dialContext(ctx context.Context, nd *net.Dialer, network, address string) (net.Conn, error) { + var conn net.Conn + var err error if h := d.customDialContext; h != nil { - return h(nd, ctx, network, address) + conn, err = h(nd, ctx, network, address) + } else { + conn, err = nd.DialContext(ctx, network, address) } - return nd.DialContext(ctx, network, address) + d.eh.DialPerformed(nd.LocalAddr, network, address, err) + return conn, err } // splitByType divides an address list into two categories: @@ -302,3 +321,7 @@ func splitByType(addrs []netip.AddrPort) (primaries []netip.AddrPort, fallbacks } return } + +type noopEventHandler struct{} + +func (s noopEventHandler) DialPerformed(net.Addr, string, string, error) {}