Skip to content

Commit 6eda7f8

Browse files
committed
re-introduce packet pool
1 parent b6c7b5d commit 6eda7f8

File tree

7 files changed

+93
-40
lines changed

7 files changed

+93
-40
lines changed

auth.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,10 +368,11 @@ func (mc *mysqlConn) handleAuthResult(ctx context.Context, oldAuthData []byte, p
368368
return err
369369
}
370370

371-
data, err = mc.readPacket(ctx)
371+
packet, err := mc.readPacket(ctx)
372372
if err != nil {
373373
return err
374374
}
375+
data = packet.data
375376

376377
if data[0] != iAuthMoreData {
377378
return fmt.Errorf("unexpected resp from server for caching_sha2_password, perform full authentication")

connection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ type mysqlConn struct {
5252
// for context support (Go 1.8+)
5353
closech chan struct{}
5454
closed atomicBool // set when conn is closed, before closech is closed
55-
readRes chan packet // channel for read result
55+
readRes chan *packet // channel for read result
5656
writeReq chan []byte // buffered channel for write packets
5757
writeRes chan writeResult // channel for write result
5858
}

connection_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ func TestPingMarkBadConnection(t *testing.T) {
157157
netConn: nc,
158158
rbuf: newReadBuffer(nc),
159159
maxAllowedPacket: defaultMaxAllowedPacket,
160+
connector: &connector{},
160161
}
161162
ms.startGoroutines()
162163
defer ms.cleanup()
@@ -180,6 +181,7 @@ func TestPingErrInvalidConn(t *testing.T) {
180181
maxAllowedPacket: defaultMaxAllowedPacket,
181182
closech: make(chan struct{}),
182183
cfg: NewConfig(),
184+
connector: &connector{},
183185
}
184186
ms.startGoroutines()
185187
defer ms.cleanup()

connector.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ import (
1616
"os"
1717
"strconv"
1818
"strings"
19+
"sync"
1920
)
2021

2122
type connector struct {
2223
cfg *Config // immutable private copy.
2324
encodedAttributes string // Encoded connection attributes.
25+
packetPool sync.Pool
2426
}
2527

2628
func encodeConnectionAttributes(textAttributes string) string {
@@ -169,6 +171,20 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
169171
return mc, nil
170172
}
171173

174+
func (c *connector) getPacket() *packet {
175+
p := c.packetPool.Get()
176+
if p == nil {
177+
return &packet{}
178+
}
179+
return p.(*packet)
180+
}
181+
182+
func (c *connector) putPacket(p *packet) {
183+
if p != nil && len(p.data) < maxPacketSize {
184+
c.packetPool.Put(p)
185+
}
186+
}
187+
172188
// Driver implements driver.Connector interface.
173189
// Driver returns &MySQLDriver{}.
174190
func (c *connector) Driver() driver.Driver {

packets.go

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,14 @@ func (p *packet) readFrom(r *readBuffer) {
5151
return
5252
}
5353

54-
p.data = append([]byte(nil), data...) // TODO: reduce allocations
54+
p.data = append(p.data[:0], data...)
5555
}
5656

5757
// Read packet to buffer 'data'
58-
func (mc *mysqlConn) readPacket(ctx context.Context) ([]byte, error) {
59-
var prevData []byte
58+
func (mc *mysqlConn) readPacket(ctx context.Context) (*packet, error) {
59+
var prevData *packet
6060
for {
61-
var pkt packet
61+
var pkt *packet
6262
select {
6363
case pkt = <-mc.readRes:
6464
case <-mc.closech:
@@ -99,12 +99,24 @@ func (mc *mysqlConn) readPacket(ctx context.Context) ([]byte, error) {
9999
return prevData, nil
100100
}
101101

102-
prevData = append(prevData, pkt.data...)
103-
104102
// return data if this was the last packet
105103
if pktLen < maxPacketSize {
104+
// zero allocations for non-split packets
105+
if prevData == nil {
106+
return pkt, nil
107+
}
108+
109+
prevData.data = append(prevData.data, pkt.data...)
110+
mc.connector.putPacket(pkt)
106111
return prevData, nil
107112
}
113+
114+
if prevData != nil {
115+
prevData.data = append(prevData.data, pkt.data...)
116+
mc.connector.putPacket(pkt)
117+
} else {
118+
prevData = pkt
119+
}
108120
}
109121
}
110122

@@ -209,7 +221,7 @@ func (mc *mysqlConn) writePacket(ctx context.Context, data []byte) error {
209221
// Handshake Initialization Packet
210222
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
211223
func (mc *mysqlConn) readHandshakePacket(ctx context.Context) (data []byte, plugin string, err error) {
212-
data, err = mc.readPacket(ctx)
224+
packet, err := mc.readPacket(ctx)
213225
if err != nil {
214226
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
215227
// in connection initialization we don't risk retrying non-idempotent actions.
@@ -218,6 +230,8 @@ func (mc *mysqlConn) readHandshakePacket(ctx context.Context) (data []byte, plug
218230
}
219231
return
220232
}
233+
defer mc.connector.putPacket(packet)
234+
data = packet.data
221235

222236
if data[0] == iERR {
223237
return nil, "", mc.handleErrorPacket(data)
@@ -504,10 +518,11 @@ func (mc *mysqlConn) writeCommandPacketUint32(ctx context.Context, command byte,
504518
******************************************************************************/
505519

506520
func (mc *mysqlConn) readAuthResult(ctx context.Context) ([]byte, string, error) {
507-
data, err := mc.readPacket(ctx)
521+
packet, err := mc.readPacket(ctx)
508522
if err != nil {
509523
return nil, "", err
510524
}
525+
data := packet.data
511526

512527
// packet indicator
513528
switch data[0] {
@@ -540,10 +555,11 @@ func (mc *mysqlConn) readAuthResult(ctx context.Context) ([]byte, string, error)
540555

541556
// Returns error if Packet is not a 'Result OK'-Packet
542557
func (mc *okHandler) readResultOK(ctx context.Context) error {
543-
data, err := mc.conn().readPacket(ctx)
558+
packet, err := mc.conn().readPacket(ctx)
544559
if err != nil {
545560
return err
546561
}
562+
data := packet.data
547563

548564
if data[0] == iOK {
549565
return mc.handleOkPacket(data)
@@ -558,10 +574,12 @@ func (mc *okHandler) readResultSetHeaderPacket(ctx context.Context) (int, error)
558574
mc.result.affectedRows = append(mc.result.affectedRows, 0)
559575
mc.result.insertIds = append(mc.result.insertIds, 0)
560576

561-
data, err := mc.conn().readPacket(ctx)
577+
packet, err := mc.conn().readPacket(ctx)
562578
if err != nil {
563579
return 0, err
564580
}
581+
defer mc.conn().connector.putPacket(packet)
582+
data := packet.data
565583
if err == nil {
566584
switch data[0] {
567585

@@ -704,10 +722,11 @@ func (mc *mysqlConn) readColumns(ctx context.Context, count int) ([]mysqlField,
704722
columns := make([]mysqlField, count)
705723

706724
for i := 0; ; i++ {
707-
data, err := mc.readPacket(ctx)
725+
packet, err := mc.readPacket(ctx)
708726
if err != nil {
709727
return nil, err
710728
}
729+
data := packet.data
711730

712731
// EOF Packet
713732
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
@@ -808,10 +827,13 @@ func (rows *textRows) readRow(dest []driver.Value) error {
808827
return io.EOF
809828
}
810829

811-
data, err := mc.readPacket(ctx)
830+
rows.mc.connector.putPacket(rows.pkt)
831+
packet, err := mc.readPacket(ctx)
832+
rows.pkt = packet
812833
if err != nil {
813834
return err
814835
}
836+
data := packet.data
815837

816838
// EOF Packet
817839
if data[0] == iEOF && len(data) == 5 {
@@ -893,10 +915,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
893915
// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read
894916
func (mc *mysqlConn) readUntilEOF(ctx context.Context) error {
895917
for {
896-
data, err := mc.readPacket(ctx)
918+
packet, err := mc.readPacket(ctx)
897919
if err != nil {
898920
return err
899921
}
922+
data := packet.data
900923

901924
switch data[0] {
902925
case iERR:
@@ -907,6 +930,7 @@ func (mc *mysqlConn) readUntilEOF(ctx context.Context) error {
907930
}
908931
return nil
909932
}
933+
mc.connector.putPacket(packet)
910934
}
911935
}
912936

@@ -917,10 +941,12 @@ func (mc *mysqlConn) readUntilEOF(ctx context.Context) error {
917941
// Prepare Result Packets
918942
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
919943
func (stmt *mysqlStmt) readPrepareResultPacket(ctx context.Context) (uint16, error) {
920-
data, err := stmt.mc.readPacket(ctx)
944+
packet, err := stmt.mc.readPacket(ctx)
921945
if err != nil {
922946
return 0, err
923947
}
948+
defer stmt.mc.connector.putPacket(packet)
949+
data := packet.data
924950
if err == nil {
925951
// packet indicator [1 byte]
926952
if data[0] != iOK {
@@ -1253,10 +1279,14 @@ func (mc *okHandler) discardResults(ctx context.Context) error {
12531279
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
12541280
func (rows *binaryRows) readRow(dest []driver.Value) error {
12551281
ctx := rows.ctx
1256-
data, err := rows.mc.readPacket(ctx)
1282+
1283+
rows.mc.connector.putPacket(rows.pkt)
1284+
packet, err := rows.mc.readPacket(ctx)
1285+
rows.pkt = packet
12571286
if err != nil {
12581287
return err
12591288
}
1289+
data := packet.data
12601290

12611291
// packet indicator [1 byte]
12621292
if data[0] != iOK {
@@ -1432,7 +1462,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
14321462

14331463
func (mc *mysqlConn) startGoroutines() {
14341464
mc.closech = make(chan struct{})
1435-
mc.readRes = make(chan packet)
1465+
mc.readRes = make(chan *packet)
14361466
mc.writeReq = make(chan []byte, 1)
14371467
mc.writeRes = make(chan writeResult)
14381468

@@ -1442,7 +1472,7 @@ func (mc *mysqlConn) startGoroutines() {
14421472

14431473
func (mc *mysqlConn) readLoop() {
14441474
for {
1445-
var pkt packet
1475+
pkt := mc.connector.getPacket()
14461476
mc.muRead.Lock()
14471477
pkt.readFrom(&mc.rbuf)
14481478
mc.muRead.Unlock()

packets_test.go

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ func TestReadPacketSingleByte(t *testing.T) {
4646
conn.Write([]byte{0x01, 0x00, 0x00, 0x00, 0xff})
4747
}()
4848

49-
data, err := mc.readPacket(context.Background())
49+
packet, err := mc.readPacket(context.Background())
5050
if err != nil {
5151
t.Fatal(err)
5252
}
53-
if len(data) != 1 {
54-
t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(data))
53+
if len(packet.data) != 1 {
54+
t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet.data))
5555
}
56-
if data[0] != 0xff {
57-
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, data[0])
56+
if packet.data[0] != 0xff {
57+
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet.data[0])
5858
}
5959
}
6060

@@ -124,11 +124,11 @@ func TestReadPacketSplit(t *testing.T) {
124124
}()
125125
// TODO: check read operation count
126126

127-
data, err := mc.readPacket(context.Background())
127+
packet, err := mc.readPacket(context.Background())
128128
if err != nil {
129129
t.Fatal(err)
130130
}
131-
if len(data) != maxPacketSize {
131+
if len(packet.data) != maxPacketSize {
132132
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(data))
133133
}
134134
if data[0] != 0x11 {
@@ -173,18 +173,18 @@ func TestReadPacketSplit(t *testing.T) {
173173
}()
174174
// TODO: check read operation count
175175

176-
data, err := mc.readPacket(context.Background())
176+
packet, err := mc.readPacket(context.Background())
177177
if err != nil {
178178
t.Fatal(err)
179179
}
180-
if len(data) != 2*maxPacketSize {
181-
t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(data))
180+
if len(packet.data) != 2*maxPacketSize {
181+
t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet.data))
182182
}
183-
if data[0] != 0x11 {
184-
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, data[0])
183+
if packet.data[0] != 0x11 {
184+
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet.data[0])
185185
}
186-
if data[2*maxPacketSize-1] != 0x44 {
187-
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, data[2*maxPacketSize-1])
186+
if packet.data[2*maxPacketSize-1] != 0x44 {
187+
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet.data[2*maxPacketSize-1])
188188
}
189189
})
190190

@@ -215,18 +215,18 @@ func TestReadPacketSplit(t *testing.T) {
215215
}()
216216
// TODO: check read operation count
217217

218-
data, err := mc.readPacket(context.Background())
218+
packet, err := mc.readPacket(context.Background())
219219
if err != nil {
220220
t.Fatal(err)
221221
}
222-
if len(data) != maxPacketSize+42 {
223-
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(data))
222+
if len(packet.data) != maxPacketSize+42 {
223+
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet.data))
224224
}
225-
if data[0] != 0x11 {
226-
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, data[0])
225+
if packet.data[0] != 0x11 {
226+
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet.data[0])
227227
}
228-
if data[maxPacketSize+41] != 0x44 {
229-
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, data[maxPacketSize+41])
228+
if packet.data[maxPacketSize+41] != 0x44 {
229+
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet.data[maxPacketSize+41])
230230
}
231231
})
232232
}

rows.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type mysqlRows struct {
2626
mc *mysqlConn
2727
ctx context.Context
2828
rs resultSet
29+
pkt *packet // current read packet
2930
}
3031

3132
type binaryRows struct {
@@ -108,6 +109,9 @@ func (rows *mysqlRows) Close() (err error) {
108109
return err
109110
}
110111

112+
rows.mc.connector.putPacket(rows.pkt)
113+
rows.pkt = nil
114+
111115
// Remove unread packets from stream
112116
if !rows.rs.done {
113117
err = mc.readUntilEOF(ctx)

0 commit comments

Comments
 (0)