From 451e17f5d71ab772acee3601e076cef9dba8785d Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 31 Oct 2023 02:41:23 +0800 Subject: [PATCH 1/5] Add support for length encoded connection attributes --- connector.go | 10 +++------- connector_test.go | 7 ++----- driver.go | 9 +++------ packets.go | 16 +++++++--------- packets_test.go | 5 +---- 5 files changed, 16 insertions(+), 31 deletions(-) diff --git a/connector.go b/connector.go index ba3be71e7..c4ea5bc25 100644 --- a/connector.go +++ b/connector.go @@ -11,7 +11,6 @@ package mysql import ( "context" "database/sql/driver" - "fmt" "net" "os" "strconv" @@ -24,7 +23,7 @@ type connector struct { } func encodeConnectionAttributes(textAttributes string) string { - connAttrsBuf := make([]byte, 0, 251) + connAttrsBuf := make([]byte, 0) // default connection attributes connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName) @@ -49,15 +48,12 @@ func encodeConnectionAttributes(textAttributes string) string { return string(connAttrsBuf) } -func newConnector(cfg *Config) (*connector, error) { +func newConnector(cfg *Config) *connector { encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes) - if len(encodedAttributes) > 250 { - return nil, fmt.Errorf("connection attributes are longer than 250 bytes: %dbytes (%q)", len(encodedAttributes), cfg.ConnectionAttributes) - } return &connector{ cfg: cfg, encodedAttributes: encodedAttributes, - }, nil + } } // Connect implements driver.Connector interface. diff --git a/connector_test.go b/connector_test.go index bedb44ce2..82d8c5989 100644 --- a/connector_test.go +++ b/connector_test.go @@ -8,16 +8,13 @@ import ( ) func TestConnectorReturnsTimeout(t *testing.T) { - connector, err := newConnector(&Config{ + connector := newConnector(&Config{ Net: "tcp", Addr: "1.1.1.1:1234", Timeout: 10 * time.Millisecond, }) - if err != nil { - t.Fatal(err) - } - _, err = connector.Connect(context.Background()) + _, err := connector.Connect(context.Background()) if err == nil { t.Fatal("error expected") } diff --git a/driver.go b/driver.go index 45528b920..105316b81 100644 --- a/driver.go +++ b/driver.go @@ -83,10 +83,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { if err != nil { return nil, err } - c, err := newConnector(cfg) - if err != nil { - return nil, err - } + c := newConnector(cfg) return c.Connect(context.Background()) } @@ -108,7 +105,7 @@ func NewConnector(cfg *Config) (driver.Connector, error) { if err := cfg.normalize(); err != nil { return nil, err } - return newConnector(cfg) + return newConnector(cfg), nil } // OpenConnector implements driver.DriverContext. @@ -117,5 +114,5 @@ func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { if err != nil { return nil, err } - return newConnector(cfg) + return newConnector(cfg), nil } diff --git a/packets.go b/packets.go index 0127232ee..49e6bb058 100644 --- a/packets.go +++ b/packets.go @@ -292,15 +292,14 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pktLen += n + 1 } - // 1 byte to store length of all key-values - // NOTE: Actually, this is length encoded integer. - // But we support only len(connAttrBuf) < 251 for now because takeSmallBuffer - // doesn't support buffer size more than 4096 bytes. - // TODO(methane): Rewrite buffer management. - pktLen += 1 + len(mc.connector.encodedAttributes) + // encode length of the connection attributes + var connAttrsLEIBuf [9]byte + connAttrsLen := len(mc.connector.encodedAttributes) + connAttrsLEI := appendLengthEncodedInteger(connAttrsLEIBuf[:0], uint64(connAttrsLen)) + pktLen += len(connAttrsLEI) + len(mc.connector.encodedAttributes) // Calculate packet length and get buffer with that size - data, err := mc.buf.takeSmallBuffer(pktLen + 4) + data, err := mc.buf.takeBuffer(pktLen + 4) if err != nil { // cannot take the buffer. Something must be wrong with the connection mc.cfg.Logger.Print(err) @@ -380,8 +379,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string pos++ // Connection Attributes - data[pos] = byte(len(mc.connector.encodedAttributes)) - pos++ + pos += copy(data[pos:], connAttrsLEI) pos += copy(data[pos:], []byte(mc.connector.encodedAttributes)) // Send Auth packet diff --git a/packets_test.go b/packets_test.go index e86ec5848..fa4683eab 100644 --- a/packets_test.go +++ b/packets_test.go @@ -96,10 +96,7 @@ var _ net.Conn = new(mockConn) func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { conn := new(mockConn) - connector, err := newConnector(NewConfig()) - if err != nil { - panic(err) - } + connector := newConnector(NewConfig()) mc := &mysqlConn{ buf: newBuffer(conn), cfg: connector.cfg, From c422dd11b87bf2cc78db4a1747afd0aeee68f657 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 31 Oct 2023 03:41:15 +0800 Subject: [PATCH 2/5] Add default connection attribute '_server_host' --- connector.go | 9 ++++++--- const.go | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/connector.go b/connector.go index c4ea5bc25..6e8d65f6d 100644 --- a/connector.go +++ b/connector.go @@ -22,7 +22,7 @@ type connector struct { encodedAttributes string // Encoded connection attributes. } -func encodeConnectionAttributes(textAttributes string) string { +func encodeConnectionAttributes(cfg *Config) string { connAttrsBuf := make([]byte, 0) // default connection attributes @@ -34,9 +34,12 @@ func encodeConnectionAttributes(textAttributes string) string { connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost) + serverHost, _, _ := net.SplitHostPort(cfg.Addr) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost) // user-defined connection attributes - for _, connAttr := range strings.Split(textAttributes, ",") { + for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") { k, v, found := strings.Cut(connAttr, ":") if !found { continue @@ -49,7 +52,7 @@ func encodeConnectionAttributes(textAttributes string) string { } func newConnector(cfg *Config) *connector { - encodedAttributes := encodeConnectionAttributes(cfg.ConnectionAttributes) + encodedAttributes := encodeConnectionAttributes(cfg) return &connector{ cfg: cfg, encodedAttributes: encodedAttributes, diff --git a/const.go b/const.go index 0f2621a6f..22526e031 100644 --- a/const.go +++ b/const.go @@ -26,6 +26,7 @@ const ( connAttrPlatform = "_platform" connAttrPlatformValue = runtime.GOARCH connAttrPid = "_pid" + connAttrServerHost = "_server_host" ) // MySQL constants documentation: From 4ec6710652a9a9fea29e2d26979e07c17626256b Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 31 Oct 2023 16:16:21 +0800 Subject: [PATCH 3/5] Update test TestConnectionAttributes --- driver_test.go | 64 +++++++++++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/driver_test.go b/driver_test.go index f256011a7..39aed8355 100644 --- a/driver_test.go +++ b/driver_test.go @@ -24,6 +24,7 @@ import ( "os" "reflect" "runtime" + "strconv" "strings" "sync" "sync/atomic" @@ -3377,11 +3378,31 @@ func TestConnectionAttributes(t *testing.T) { t.Skipf("MySQL server not running on %s", netAddr) } - attr1 := "attr1" - value1 := "value1" - attr2 := "foo" - value2 := "boo" - dsn += fmt.Sprintf("&connectionAttributes=%s:%s,%s:%s", attr1, value1, attr2, value2) + defaultAttrs := []string{ + connAttrClientName, + connAttrOS, + connAttrPlatform, + connAttrPid, + connAttrServerHost, + } + host, _, _ := net.SplitHostPort(addr) + defaultAttrValues := []string{ + connAttrClientNameValue, + connAttrOSValue, + connAttrPlatformValue, + strconv.Itoa(os.Getpid()), + host, + } + + customAttrs := []string{"attr1", "attr2"} + customAttrValues := []string{"foo", "bar"} + + customAttrStrs := make([]string, len(customAttrs)) + for i := range customAttrs { + customAttrStrs[i] = fmt.Sprintf("%s:%s", customAttrs[i], customAttrValues[i]) + } + + dsn += fmt.Sprintf("&connectionAttributes=%s", strings.Join(customAttrStrs, ",")) var db *sql.DB if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { @@ -3394,27 +3415,22 @@ func TestConnectionAttributes(t *testing.T) { dbt := &DBTest{t, db} - var attrValue string - queryString := "SELECT ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID() and ATTR_NAME = ?" - rows := dbt.mustQuery(queryString, connAttrClientName) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != connAttrClientNameValue { - dbt.Errorf("expected %q, got %q", connAttrClientNameValue, attrValue) - } - } else { - dbt.Errorf("no data") + queryString := "SELECT ATTR_NAME, ATTR_VALUE FROM performance_schema.session_account_connect_attrs WHERE PROCESSLIST_ID = CONNECTION_ID()" + rows := dbt.mustQuery(queryString) + defer rows.Close() + + rowsMap := make(map[string]string) + for rows.Next() { + var attrName, attrValue string + rows.Scan(&attrName, &attrValue) + rowsMap[attrName] = attrValue } - rows.Close() - rows = dbt.mustQuery(queryString, attr2) - if rows.Next() { - rows.Scan(&attrValue) - if attrValue != value2 { - dbt.Errorf("expected %q, got %q", value2, attrValue) + connAttrs := append(append([]string{}, defaultAttrs...), customAttrs...) + expectedAttrValues := append(append([]string{}, defaultAttrValues...), customAttrValues...) + for i := range connAttrs { + if gotValue := rowsMap[connAttrs[i]]; gotValue != expectedAttrValues[i] { + dbt.Errorf("expected %s, got %s", expectedAttrValues[i], gotValue) } - } else { - dbt.Errorf("no data") } - rows.Close() } From 61915a60286310dc82a82b1c50d30505175310e1 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 31 Oct 2023 16:18:19 +0800 Subject: [PATCH 4/5] Update AUTHORS --- AUTHORS | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AUTHORS b/AUTHORS index dec27daca..0491c4c64 100644 --- a/AUTHORS +++ b/AUTHORS @@ -49,6 +49,7 @@ INADA Naoki Jacek Szwec James Harr Janek Vedock +Jason Ng Jean-Yves Pellé Jeff Hodges Jeffrey Charles @@ -128,6 +129,7 @@ Keybase Inc. Multiplay Ltd. Percona LLC Pivotal Inc. +Shattered Silicon Ltd. Stripe Inc. Zendesk Inc. Dolthub Inc. From b20004ae28c8a524de0d7522cd47f5b52442bb90 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 14 Nov 2023 21:54:35 +0800 Subject: [PATCH 5/5] Don't send '_server_host' attribute if unix socket is used --- connector.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/connector.go b/connector.go index 6e8d65f6d..3cef7963f 100644 --- a/connector.go +++ b/connector.go @@ -34,9 +34,11 @@ func encodeConnectionAttributes(cfg *Config) string { connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue) connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid) connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid())) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost) serverHost, _, _ := net.SplitHostPort(cfg.Addr) - connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost) + if serverHost != "" { + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost) + connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost) + } // user-defined connection attributes for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {