diff --git a/AUTHORS b/AUTHORS index 876b2964a..ff6c2aec8 100644 --- a/AUTHORS +++ b/AUTHORS @@ -16,6 +16,7 @@ Achille Roussel Alex Snast Alexey Palazhchenko Andrew Reid +Andy Grunwald Animesh Ray Arne Hormann Ariel Mashraki @@ -95,6 +96,7 @@ Tan Jinhua <312841925 at qq.com> Thomas Wodarek Tim Ruffles Tom Jenkinson +Vasily Fedoseyev Vladimir Kovpak Vladyslav Zhelezniak Xiangyu Hu diff --git a/README.md b/README.md index ded6e3b16..c32f88d99 100644 --- a/README.md +++ b/README.md @@ -242,6 +242,16 @@ SELECT u.id FROM users as u will return `u.id` instead of just `id` if `columnsWithAlias=true`. +##### `connectAttrs` + +``` +Type: map +Valid Values: comma-separated list of attribute:value pairs +Default: empty +``` + +Allows setting of connection attributes, for example `connectAttrs=program_name:YourProgramName` will show `YourProgramName` in `Program` field of connections list of Mysql Workbench, if your server supports it (requires `performance_schema` to be supported and enabled). + ##### `interpolateParams` ``` diff --git a/const.go b/const.go index b1e6b85ef..6118a5e07 100644 --- a/const.go +++ b/const.go @@ -46,7 +46,7 @@ const ( clientIgnoreSIGPIPE clientTransactions clientReserved - clientSecureConn + clientSecureConn // reserved2 in 8.0 clientMultiStatements clientMultiResults clientPSMultiResults @@ -56,6 +56,8 @@ const ( clientCanHandleExpiredPasswords clientSessionTrack clientDeprecateEOF + clientSslVerifyServerCert clientFlag = 1 << 30 + clientRememberOptions clientFlag = 1 << 31 ) const ( diff --git a/driver_test.go b/driver_test.go index 4850498d0..6a0ee9bd3 100644 --- a/driver_test.go +++ b/driver_test.go @@ -2189,6 +2189,59 @@ func TestEmptyPassword(t *testing.T) { } } +func TestConnectAttrs(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + db, err := sql.Open("mysql", dsn+"&connectAttrs=program_name:GoTest,foo:bar") + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + dbt := &DBTest{t, db} + + rows := dbt.mustQuery("SHOW VARIABLES LIKE 'performance_schema'") + if rows.Next() { + var var_name, value string + rows.Scan(&var_name, &value) + if value != "ON" { + t.Skip("performance_schema is disabled") + } + } else { + t.Skip("no performance_schema variable in mysql") + } + rows.Close() + + rows, err = dbt.db.Query("SELECT attr_value FROM performance_schema.session_connect_attrs WHERE processlist_id=CONNECTION_ID() AND attr_name='program_name'") + if err != nil { + dbt.Skipf("server probably does not support performance_schema.session_connect_attrs: %s", err) + } + + if rows.Next() { + var str string + rows.Scan(&str) + if "GoTest" != str { + dbt.Errorf("GoTest != %s", str) + } + } else { + dbt.Error("no data for program_name") + } + rows.Close() + + rows = dbt.mustQuery("SELECT attr_value FROM performance_schema.session_connect_attrs WHERE processlist_id=CONNECTION_ID() AND attr_name='foo'") + if rows.Next() { + var str string + rows.Scan(&str) + if "bar" != str { + dbt.Errorf("bar != %s", str) + } + } else { + dbt.Error("no data for custom attribute") + } + rows.Close() +} + // static interface implementation checks of mysqlConn var ( _ driver.ConnBeginTx = &mysqlConn{} diff --git a/dsn.go b/dsn.go index a306d66a3..7227ae7f5 100644 --- a/dsn.go +++ b/dsn.go @@ -40,6 +40,7 @@ type Config struct { Addr string // Network address (requires Net) DBName string // Database name Params map[string]string // Connection parameters + ConnectAttrs map[string]string // Connection attributes Collation string // Connection collation Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed @@ -272,6 +273,30 @@ func (cfg *Config) FormatDSN() string { writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket)) } + if len(cfg.ConnectAttrs) > 0 { + // connectAttrs=program_name:Login Server,other_name:other + if hasParam { + buf.WriteString("&connectAttrs=") + } else { + hasParam = true + buf.WriteString("?connectAttrs=") + } + + var attr_names []string + for attr_name := range cfg.ConnectAttrs { + attr_names = append(attr_names, attr_name) + } + sort.Strings(attr_names) + for index, attr_name := range attr_names { + if index > 0 { + buf.WriteByte(',') + } + buf.WriteString(attr_name) + buf.WriteByte(':') + buf.WriteString(url.QueryEscape(cfg.ConnectAttrs[attr_name])) + } + } + // other params if cfg.Params != nil { var params []string @@ -536,6 +561,24 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return } + case "connectAttrs": + if cfg.ConnectAttrs == nil { + cfg.ConnectAttrs = make(map[string]string) + } + + var ConnectAttrs string + if ConnectAttrs, err = url.QueryUnescape(value); err != nil { + return + } + + // program_name:Name,foo:bar + for _, attr_str := range strings.Split(ConnectAttrs, ",") { + attr := strings.SplitN(attr_str, ":", 2) + if len(attr) != 2 { + continue + } + cfg.ConnectAttrs[attr[0]] = attr[1] + } default: // lazy init if cfg.Params == nil { diff --git a/dsn_test.go b/dsn_test.go index fc6eea9c8..c34512c48 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -71,6 +71,9 @@ var testDSNs = []struct { }, { "tcp(de:ad:be:ef::ca:fe)/dbname", &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "tcp(127.0.0.1)/dbname?connectAttrs=program_name:SomeService", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", ConnectAttrs: map[string]string{"program_name": "SomeService"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, }, } @@ -403,6 +406,20 @@ func TestNormalizeTLSConfig(t *testing.T) { } } +func TestAttributesAreSorted(t *testing.T) { + expected := "/dbname?connectAttrs=p1:v1,p2:v2" + cfg := NewConfig() + cfg.DBName = "dbname" + cfg.ConnectAttrs = map[string]string{ + "p2": "v2", + "p1": "v1", + } + actual := cfg.FormatDSN() + if actual != expected { + t.Errorf("generic Config.ConnectAttrs were not sorted: want %#v, got %#v", expected, actual) + } +} + func BenchmarkParseDSN(b *testing.B) { b.ReportAllocs() diff --git a/packets.go b/packets.go index 1867ecab2..9ac62fae2 100644 --- a/packets.go +++ b/packets.go @@ -18,6 +18,10 @@ import ( "fmt" "io" "math" + "os" + "runtime" + "strconv" + "strings" "time" ) @@ -235,10 +239,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 1 + 2 + // capability flags (upper 2 bytes) [2 bytes] + mc.flags |= clientFlag(uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16) + pos += 2 + // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += 1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -312,9 +321,42 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + if clientFlags&clientSecureConn == 0 || clientFlags&clientPluginAuthLenEncClientData == 0 { + pktLen++ + } + + connectAttrsBuf := make([]byte, 0, 100) + if mc.flags&clientConnectAttrs != 0 { + clientFlags |= clientConnectAttrs + + // Set default connection attributes + // See https://dev.mysql.com/doc/refman/8.0/en/performance-schema-connection-attribute-tables.html#performance-schema-connection-attributes-available + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_client_name")) + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("github.com/go-sql-driver/mysql")) + + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_os")) + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(runtime.GOOS)) + + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_platform")) + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(runtime.GOARCH)) + + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_pid")) + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(strconv.Itoa(os.Getpid()))) + + for k, v := range mc.cfg.ConnectAttrs { + if strings.HasPrefix(k, "_") { + return errors.New("connection attributes cannot start with '_'. They are reserved for internal usage") + } + + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(k)) + connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(v)) + } + connectAttrsBuf = appendLengthEncodedString(make([]byte, 0, 100), connectAttrsBuf) + pktLen += len(connectAttrsBuf) + } // To specify a db name - if n := len(mc.cfg.DBName); n > 0 { + if n := len(mc.cfg.DBName); mc.flags&clientConnectWithDB != 0 && n > 0 { clientFlags |= clientConnectWithDB pktLen += n + 1 } @@ -380,20 +422,39 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string data[pos] = 0x00 pos++ - // Auth Data [length encoded integer] - pos += copy(data[pos:], authRespLEI) + // Auth Data [length encoded integer + data] if clientPluginAuthLenEncClientData + // clientSecureConn => 1 byte len + data + // else null-terminated + if clientFlags&clientPluginAuthLenEncClientData != 0 { + pos += copy(data[pos:], authRespLEI) + } else if clientFlags&clientSecureConn != 0 { + data[pos] = uint8(len(authResp)) + pos++ + } pos += copy(data[pos:], authResp) + if clientFlags&clientSecureConn == 0 && clientFlags&clientPluginAuthLenEncClientData == 0 { + data[pos] = 0x00 + pos++ + } // Databasename [null terminated string] - if len(mc.cfg.DBName) > 0 { + if clientFlags&clientConnectWithDB != 0 { pos += copy(data[pos:], mc.cfg.DBName) data[pos] = 0x00 pos++ } - pos += copy(data[pos:], plugin) - data[pos] = 0x00 - pos++ + // auth plugin name [null terminated string] + if clientFlags&clientPluginAuth != 0 { + pos += copy(data[pos:], plugin) + data[pos] = 0x00 + pos++ + } + + // connection attributes [lenenc-int total + lenenc-str key-value pairs] + if clientFlags&clientConnectAttrs != 0 { + pos += copy(data[pos:], connectAttrsBuf) + } // Send Auth packet return mc.writePacket(data[:pos]) diff --git a/utils.go b/utils.go index bcdee1b46..d303f9dc0 100644 --- a/utils.go +++ b/utils.go @@ -579,6 +579,12 @@ func skipLengthEncodedString(b []byte) (int, error) { return n, io.EOF } +// encodes a bytes slice with prepended length-encoded size and appends it to the given bytes slice +func appendLengthEncodedString(b []byte, str []byte) []byte { + b = appendLengthEncodedInteger(b, uint64(len(str))) + return append(b, str...) +} + // returns the number read, whether the value is NULL and the number of bytes read func readLengthEncodedInteger(b []byte) (uint64, bool, int) { // See issue #349