diff --git a/dsn.go b/dsn.go index 6ce5cc020..af3dfa303 100644 --- a/dsn.go +++ b/dsn.go @@ -28,7 +28,9 @@ var ( errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") ) -// Config is a configuration parsed from a DSN string +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. type Config struct { User string // Username Passwd string // Password (requires User) @@ -57,6 +59,43 @@ type Config struct { RejectReadOnly bool // Reject read-only connections } +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { + return &Config{ + Collation: defaultCollation, + Loc: time.UTC, + AllowNativePasswords: true, + } +} + +func (cfg *Config) normalize() error { + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return errors.New("default addr for network '" + cfg.Net + "' unknown") + } + + } else if cfg.Net == "tcp" { + cfg.Addr = ensureHavePort(cfg.Addr) + } + + return nil +} + // FormatDSN formats the given Config into a DSN string which can be passed to // the driver. func (cfg *Config) FormatDSN() string { @@ -273,11 +312,7 @@ func (cfg *Config) FormatDSN() string { // ParseDSN parses the DSN string to a Config func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = &Config{ - Loc: time.UTC, - Collation: defaultCollation, - AllowNativePasswords: true, - } + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -345,31 +380,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { return nil, errInvalidDSNNoSlash } - if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { - return nil, errInvalidDSNUnsafeCollation - } - - // Set default network if empty - if cfg.Net == "" { - cfg.Net = "tcp" + if err = cfg.normalize(); err != nil { + return nil, err } - - // Set default address if empty - if cfg.Addr == "" { - switch cfg.Net { - case "tcp": - cfg.Addr = "127.0.0.1:3306" - case "unix": - cfg.Addr = "/tmp/mysql.sock" - default: - return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") - } - - } - if cfg.Net == "tcp" { - cfg.Addr = ensureHavePort(cfg.Addr) - } - return } diff --git a/dsn_test.go b/dsn_test.go index 01b57212f..af28da351 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -98,6 +98,7 @@ func TestDSNParserInvalid(t *testing.T) { "(/", // no closing brace "net(addr)//", // unescaped "User:pass@tcp(1.2.3.4:3306)", // no trailing slash + "net()/", // unknown default addr //"/dbname?arg=/some/unescaped/path", }