Skip to content

Commit c6deadb

Browse files
committed
Support for sending connection attributes
1 parent 2cc627a commit c6deadb

File tree

7 files changed

+145
-1
lines changed

7 files changed

+145
-1
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ Shuode Li <elemount at qq.com>
6666
Soroush Pour <me at soroushjp.com>
6767
Stan Putrya <root.vagner at gmail.com>
6868
Stanley Gunawan <gunawan.stanley at gmail.com>
69+
Vasily Fedoseyev <vasilyfedoseyev at gmail.com>
6970
Xiangyu Hu <xiangyu.hu at outlook.com>
7071
Xiaobing Jiang <s7v7nislands at gmail.com>
7172
Xiuming Chen <cc at cxm.cc>

README.md

+10
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,16 @@ SELECT u.id FROM users as u
204204

205205
will return `u.id` instead of just `id` if `columnsWithAlias=true`.
206206

207+
##### `connectionAttributes`
208+
209+
```
210+
Type: map
211+
Valid Values: comma-separated list of attribute:value pairs
212+
Default: empty
213+
```
214+
215+
Allows setting of connection attributes, for example `connectionAttributes=program_name:YourProgramName` will show `YourProgramName` in `Program` field of connections list of Mysql Workbench.
216+
207217
##### `interpolateParams`
208218

209219
```

driver_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -2018,3 +2018,42 @@ func TestPing(t *testing.T) {
20182018
}
20192019
})
20202020
}
2021+
2022+
func TestConnectionAttributes(t *testing.T) {
2023+
if !available {
2024+
t.Skipf("MySQL server not running on %s", netAddr)
2025+
}
2026+
2027+
db, err := sql.Open("mysql", dsn+"&connectionAttributes=program_name:GoTest,foo:bar")
2028+
if err != nil {
2029+
t.Fatalf("error connecting: %s", err.Error())
2030+
}
2031+
defer db.Close()
2032+
dbt := &DBTest{t, db}
2033+
2034+
rows, err := dbt.db.Query("SELECT program_name FROM sys.processlist WHERE db=?", dbname)
2035+
if err != nil {
2036+
dbt.Skip("server probably does not support program_name in sys.processlist")
2037+
}
2038+
2039+
if rows.Next() {
2040+
var str string
2041+
rows.Scan(&str)
2042+
if "GoTest" != str {
2043+
dbt.Errorf("GoTest != %s", str)
2044+
}
2045+
} else {
2046+
dbt.Error("no data")
2047+
}
2048+
2049+
rows = dbt.mustQuery("select attr_value from performance_schema.session_account_connect_attrs where processlist_id=CONNECTION_ID() and attr_name='foo'")
2050+
if rows.Next() {
2051+
var str string
2052+
rows.Scan(&str)
2053+
if "bar" != str {
2054+
dbt.Errorf("bar != %s", str)
2055+
}
2056+
} else {
2057+
dbt.Error("no data")
2058+
}
2059+
}

dsn.go

+43
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ type Config struct {
3838
Addr string // Network address (requires Net)
3939
DBName string // Database name
4040
Params map[string]string // Connection parameters
41+
Attributes map[string]string // Connection attributes
4142
Collation string // Connection collation
4243
Loc *time.Location // Location for time.Time values
4344
MaxAllowedPacket int // Max packet size allowed
@@ -295,6 +296,30 @@ func (cfg *Config) FormatDSN() string {
295296

296297
}
297298

299+
if len(cfg.Attributes) > 0 {
300+
// connectionAttributes=program_name:Login Server,other_name:other
301+
if hasParam {
302+
buf.WriteString("&connectionAttributes=")
303+
} else {
304+
hasParam = true
305+
buf.WriteString("?connectionAttributes=")
306+
}
307+
308+
var attr_names []string
309+
for attr_name := range cfg.Attributes {
310+
attr_names = append(attr_names, attr_name)
311+
}
312+
sort.Strings(attr_names)
313+
for index, attr_name := range attr_names {
314+
if index > 0 {
315+
buf.WriteByte(',')
316+
}
317+
buf.WriteString(attr_name)
318+
buf.WriteByte(':')
319+
buf.WriteString(url.QueryEscape(cfg.Attributes[attr_name]))
320+
}
321+
}
322+
298323
// other params
299324
if cfg.Params != nil {
300325
var params []string
@@ -561,6 +586,24 @@ func parseDSNParams(cfg *Config, params string) (err error) {
561586
if err != nil {
562587
return
563588
}
589+
case "connectionAttributes":
590+
if cfg.Attributes == nil {
591+
cfg.Attributes = make(map[string]string)
592+
}
593+
594+
var attributes string
595+
if attributes, err = url.QueryUnescape(value); err != nil {
596+
return
597+
}
598+
599+
// program_name:Name,foo:bar
600+
for _, attr_str := range strings.Split(attributes, ",") {
601+
attr := strings.SplitN(attr_str, ":", 2)
602+
if len(attr) != 2 {
603+
continue
604+
}
605+
cfg.Attributes[attr[0]] = attr[1]
606+
}
564607
default:
565608
// lazy init
566609
if cfg.Params == nil {

dsn_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ var testDSNs = []struct {
7171
}, {
7272
"tcp(de:ad:be:ef::ca:fe)/dbname",
7373
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
74+
}, {
75+
"tcp(127.0.0.1)/dbname?connectionAttributes=program_name:SomeService",
76+
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Attributes: map[string]string{"program_name": "SomeService"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
7477
},
7578
}
7679

@@ -274,6 +277,20 @@ func TestParamsAreSorted(t *testing.T) {
274277
}
275278
}
276279

280+
func TestAttributesAreSorted(t *testing.T) {
281+
expected := "/dbname?connectionAttributes=p1:v1,p2:v2"
282+
cfg := NewConfig()
283+
cfg.DBName = "dbname"
284+
cfg.Attributes = map[string]string{
285+
"p2": "v2",
286+
"p1": "v1",
287+
}
288+
actual := cfg.FormatDSN()
289+
if actual != expected {
290+
t.Errorf("generic Config.Attributes were not sorted: want %#v, got %#v", expected, actual)
291+
}
292+
}
293+
277294
func BenchmarkParseDSN(b *testing.B) {
278295
b.ReportAllocs()
279296

packets.go

+29-1
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,15 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
196196
if len(data) > pos {
197197
// character set [1 byte]
198198
// status flags [2 bytes]
199+
pos += 1 + 2
200+
199201
// capability flags (upper 2 bytes) [2 bytes]
202+
mc.flags |= clientFlag(uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16)
203+
pos += 2
204+
200205
// length of auth-plugin-data [1 byte]
201206
// reserved (all [00]) [10 bytes]
202-
pos += 1 + 2 + 2 + 1 + 10
207+
pos += 1 + 10
203208

204209
// second part of the password cipher [mininum 13 bytes],
205210
// where len=MAX(13, length of auth-plugin-data - 8)
@@ -267,6 +272,24 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
267272

268273
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
269274

275+
connectAttrsBuf := make([]byte, 0, 100)
276+
if mc.flags&clientConnectAttrs != 0 {
277+
clientFlags |= clientConnectAttrs
278+
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("_client_name"))
279+
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte("go-mysql-driver"))
280+
281+
for k, v := range mc.cfg.Attributes {
282+
if k == "_client_name" {
283+
// do not allow overwriting reserved values
284+
continue
285+
}
286+
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(k))
287+
connectAttrsBuf = appendLengthEncodedString(connectAttrsBuf, []byte(v))
288+
}
289+
connectAttrsBuf = appendLengthEncodedString(make([]byte, 0, 100), connectAttrsBuf)
290+
pktLen += len(connectAttrsBuf)
291+
}
292+
270293
// To specify a db name
271294
if n := len(mc.cfg.DBName); n > 0 {
272295
clientFlags |= clientConnectWithDB
@@ -347,6 +370,11 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
347370
// Assume native client during response
348371
pos += copy(data[pos:], "mysql_native_password")
349372
data[pos] = 0x00
373+
pos++
374+
375+
if clientFlags&clientConnectAttrs != 0 {
376+
pos += copy(data[pos:], connectAttrsBuf)
377+
}
350378

351379
// Send Auth packet
352380
return mc.writePacket(data)

utils.go

+6
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,12 @@ func skipLengthEncodedString(b []byte) (int, error) {
560560
return n, io.EOF
561561
}
562562

563+
// encodes a bytes slice with prepended length-encoded size and appends it to the given bytes slice
564+
func appendLengthEncodedString(b []byte, str []byte) []byte {
565+
b = appendLengthEncodedInteger(b, uint64(len(str)))
566+
return append(b, str...)
567+
}
568+
563569
// returns the number read, whether the value is NULL and the number of bytes read
564570
func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
565571
// See issue #349

0 commit comments

Comments
 (0)