diff --git a/AUTHORS b/AUTHORS index fb1478c3b..36a947268 100644 --- a/AUTHORS +++ b/AUTHORS @@ -74,6 +74,7 @@ Maciej Zimnoch Michael Woolnough Nathanial Murphy Nicola Peduzzi +Oliver Bone Olivier Mengué oscarzhao Paul Bonser diff --git a/packets.go b/packets.go index ee05c95a8..c916ff913 100644 --- a/packets.go +++ b/packets.go @@ -43,10 +43,20 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) // check packet sync [8 bit] - if data[3] != mc.sequence { - if data[3] > mc.sequence { + if sequenceID := data[3]; sequenceID != mc.sequence { + if sequenceID > mc.sequence { return nil, ErrPktSyncMul } + + if sequenceID == 0 { + // If the sequence ID is zero then this is not a response to any + // packet that the client has sent. It's likely an error packet + // that the server sent simultaneously. + errLog.Print("received packet with zero sequence ID") + mc.Close() + return nil, ErrInvalidConn + } + return nil, ErrPktSync } mc.sequence++ diff --git a/packets_test.go b/packets_test.go index b61e4dbf7..0a9de0869 100644 --- a/packets_test.go +++ b/packets_test.go @@ -134,9 +134,9 @@ func TestReadPacketWrongSequenceID(t *testing.T) { } // too low sequence id - conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.data = []byte{0x01, 0x00, 0x00, 0x01, 0xff} conn.maxReads = 1 - mc.sequence = 1 + mc.sequence = 2 _, err := mc.readPacket() if err != ErrPktSync { t.Errorf("expected ErrPktSync, got %v", err) @@ -155,6 +155,27 @@ func TestReadPacketWrongSequenceID(t *testing.T) { } } +func TestReadPacketZeroSequenceID(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + closech: make(chan struct{}), + } + + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + mc.sequence = 1 + _, err := mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // The connection should not be returned to the pool. + if mc.IsValid() { + t.Errorf("expected connection to no longer be valid") + } +} + func TestReadPacketSplit(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{