diff --git a/connector.go b/connector.go index 62012dba..769b3adc 100644 --- a/connector.go +++ b/connector.go @@ -87,20 +87,25 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime // Connect to Server - dialsLock.RLock() - dial, ok := dials[mc.cfg.Net] - dialsLock.RUnlock() - if ok { - dctx := ctx - if mc.cfg.Timeout > 0 { - var cancel context.CancelFunc - dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) - defer cancel() - } - mc.netConn, err = dial(dctx, mc.cfg.Addr) + dctx := ctx + if mc.cfg.Timeout > 0 { + var cancel context.CancelFunc + dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) + defer cancel() + } + + if c.cfg.DialFunc != nil { + mc.netConn, err = c.cfg.DialFunc(dctx, mc.cfg.Net, mc.cfg.Addr) } else { - nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + mc.netConn, err = dial(dctx, mc.cfg.Addr) + } else { + nd := net.Dialer{} + mc.netConn, err = nd.DialContext(dctx, mc.cfg.Net, mc.cfg.Addr) + } } if err != nil { return nil, err diff --git a/dsn.go b/dsn.go index 3c7a6e21..f391a8fc 100644 --- a/dsn.go +++ b/dsn.go @@ -55,6 +55,8 @@ type Config struct { ReadTimeout time.Duration // I/O read timeout WriteTimeout time.Duration // I/O write timeout Logger Logger // Logger + // DialFunc specifies the dial function for creating connections + DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) // boolean fields