diff --git a/rpc/client/call_options.go b/rpc/client/call_options.go index 8134e46..4fe8791 100644 --- a/rpc/client/call_options.go +++ b/rpc/client/call_options.go @@ -2,13 +2,16 @@ package client import ( "context" + + "google.golang.org/grpc" ) // CallOption is a messaging session option within Protobuf RPC. type CallOption func(*callParameters) type callParameters struct { - ctx context.Context // nolint:containedctx + ctx context.Context // nolint:containedctx + dialer func(context.Context, grpc.ClientConnInterface) error } func defaultCallParameters() *callParameters { @@ -27,3 +30,11 @@ func WithContext(ctx context.Context) CallOption { prm.ctx = ctx } } + +// WithDialer returns option to specify grpc dialer. If passed, it will be +// called after the connection is successfully created. +func WithDialer(dialer func(context.Context, grpc.ClientConnInterface) error) CallOption { + return func(prm *callParameters) { + prm.dialer = dialer + } +} diff --git a/rpc/client/connect.go b/rpc/client/connect.go index 29f4189..e22e0a6 100644 --- a/rpc/client/connect.go +++ b/rpc/client/connect.go @@ -12,7 +12,7 @@ import ( var errInvalidEndpoint = errors.New("invalid endpoint options") -func (c *Client) openGRPCConn(ctx context.Context) error { +func (c *Client) openGRPCConn(ctx context.Context, dialer func(ctx context.Context, cc grpcstd.ClientConnInterface) error) error { if c.conn != nil { return nil } @@ -21,15 +21,21 @@ func (c *Client) openGRPCConn(ctx context.Context) error { return errInvalidEndpoint } - dialCtx, cancel := context.WithTimeout(ctx, c.dialTimeout) var err error - c.conn, err = grpcstd.DialContext(dialCtx, c.addr, c.grpcDialOpts...) - - cancel() - + c.conn, err = grpcstd.NewClient(c.addr, c.grpcDialOpts...) if err != nil { - return fmt.Errorf("gRPC dial: %w", err) + return fmt.Errorf("gRPC new client: %w", err) + } + + if dialer != nil { + ctx, cancel := context.WithTimeout(ctx, c.dialTimeout) + defer cancel() + + if err := dialer(ctx, c.conn); err != nil { + _ = c.conn.Close() + return fmt.Errorf("gRPC dial: %w", err) + } } return nil diff --git a/rpc/client/init.go b/rpc/client/init.go index 60ccda9..be8d066 100644 --- a/rpc/client/init.go +++ b/rpc/client/init.go @@ -46,7 +46,7 @@ func (c *Client) Init(info common.CallMethodInfo, opts ...CallOption) (MessageRe opt(prm) } - if err := c.openGRPCConn(prm.ctx); err != nil { + if err := c.openGRPCConn(prm.ctx, prm.dialer); err != nil { return nil, err } diff --git a/rpc/client/options.go b/rpc/client/options.go index 22358a3..5711cd4 100644 --- a/rpc/client/options.go +++ b/rpc/client/options.go @@ -37,7 +37,6 @@ func (c *cfg) initDefault() { c.dialTimeout = defaultDialTimeout c.rwTimeout = defaultRWTimeout c.grpcDialOpts = []grpc.DialOption{ - grpc.WithBlock(), grpc.WithTransportCredentials(insecure.NewCredentials()), } }