package net

import (
	"context"
	"net"
	"sync"

	"git.frostfs.info/TrueCloudLab/multinet"
)

// DialerSource supplies dial functions backed by a multinet dialer built from
// Config. The configuration can be replaced at runtime with Update.
type DialerSource struct {
	// guard protects c and md.
	guard sync.RWMutex

	// c is the currently applied configuration.
	c Config

	// md is the multinet dialer built from c; nil when c.Enabled is false.
	md multinet.Dialer
}

// NewDialerSource creates a DialerSource initialised from c.
// It returns an error if c is enabled but a multinet dialer cannot be
// constructed from it.
func NewDialerSource(c Config) (*DialerSource, error) {
	var ds DialerSource
	err := ds.build(c)
	if err != nil {
		return nil, err
	}
	return &ds, nil
}

// build (re)initialises s from c. When c.Enabled is set, a new multinet
// dialer is constructed from c; otherwise the dialer is cleared.
// On error the receiver is left untouched, so a failed rebuild keeps the
// previous configuration in effect. Update calls it with s.guard held for
// writing; NewDialerSource calls it before s is shared.
func (s *DialerSource) build(c Config) error {
	var dialer multinet.Dialer
	if c.Enabled {
		mc, err := c.toMultinetConfig()
		if err != nil {
			return err
		}
		if dialer, err = multinet.NewDialer(mc); err != nil {
			return err
		}
	}
	s.md, s.c = dialer, c
	return nil
}

// GrpcContextDialer returns a dial function suitable for grpc.WithContextDialer.
// It returns nil if multinet is disabled at the time of the call, letting the
// caller keep gRPC's default dialer.
//
// The returned function looks up the current multinet dialer on every call,
// under the read lock, so it observes later Update calls without racing with
// them. If an Update has since disabled multinet, dialing falls back to a
// plain net.Dialer instead of calling through a nil multinet dialer.
func (s *DialerSource) GrpcContextDialer() func(context.Context, string) (net.Conn, error) {
	s.guard.RLock()
	defer s.guard.RUnlock()

	if !s.c.Enabled {
		return nil
	}
	return func(ctx context.Context, address string) (net.Conn, error) {
		network, address := parseDialTarget(address)

		// Snapshot the dialer under the lock because Update may replace s.md
		// concurrently. The lock is released before dialing so a slow dial
		// does not block configuration updates.
		s.guard.RLock()
		md := s.md
		s.guard.RUnlock()

		if md == nil {
			var d net.Dialer
			return d.DialContext(ctx, network, address)
		}
		return md.DialContext(ctx, network, address)
	}
}

// NetContextDialer returns a dial function with the signature of
// net.Dialer.DialContext. It returns nil if multinet is disabled at the time
// of the call.
//
// The returned function looks up the current multinet dialer on every call,
// under the read lock, so it observes later Update calls without racing with
// them. If an Update has since disabled multinet, dialing falls back to a
// plain net.Dialer instead of calling through a nil multinet dialer.
func (s *DialerSource) NetContextDialer() func(context.Context, string, string) (net.Conn, error) {
	s.guard.RLock()
	defer s.guard.RUnlock()

	if !s.c.Enabled {
		return nil
	}
	return func(ctx context.Context, network, address string) (net.Conn, error) {
		// Snapshot the dialer under the lock because Update may replace s.md
		// concurrently. The lock is released before dialing so a slow dial
		// does not block configuration updates.
		s.guard.RLock()
		md := s.md
		s.guard.RUnlock()

		if md == nil {
			var d net.Dialer
			return d.DialContext(ctx, network, address)
		}
		return md.DialContext(ctx, network, address)
	}
}

// Update applies a new configuration. It does nothing when c equals the
// currently applied configuration; otherwise the dialer is rebuilt from c.
// If rebuilding fails, the error is returned and the previous configuration
// remains in effect.
func (s *DialerSource) Update(c Config) error {
	s.guard.Lock()
	defer s.guard.Unlock()

	if !s.c.equals(c) {
		return s.build(c)
	}
	return nil
}