diff --git a/AUTHORS b/AUTHORS index 0a22d46e7..a920017a5 100644 --- a/AUTHORS +++ b/AUTHORS @@ -12,6 +12,7 @@ # Individual Persons Aaron Hopkins +Alexander Menzhinsky Arne Hormann Asta Xie Carlos Nieto diff --git a/buffer.go b/buffer.go index 2001feacd..5be82bc7f 100644 --- a/buffer.go +++ b/buffer.go @@ -109,39 +109,39 @@ func (b *buffer) readNext(need int) ([]byte, error) { // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { +func (b *buffer) takeBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrUnreadTxRows } // test (cheap) general case first if length <= defaultBufSize || length <= cap(b.buf) { - return b.buf[:length] + return b.buf[:length], nil } if length < maxPacketSize { b.buf = make([]byte, length) - return b.buf + return b.buf, nil } - return make([]byte, length) + return make([]byte, length), nil } // shortcut which can be used if the requested buffer is guaranteed to be // smaller than defaultBufSize // Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { - if b.length == 0 { - return b.buf[:length] +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { + if b.length > 0 { + return nil, ErrUnreadTxRows } - return nil + return b.buf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. // Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() []byte { - if b.length == 0 { - return b.buf +func (b *buffer) takeCompleteBuffer() ([]byte, error) { + if b.length > 0 { + return nil, ErrUnreadTxRows } - return nil + return b.buf, nil } diff --git a/connection.go b/connection.go index cdce3e30f..d329a9dee 100644 --- a/connection.go +++ b/connection.go @@ -141,11 +141,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - buf := mc.buf.takeCompleteBuffer() - if buf == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return "", driver.ErrBadConn + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { + return "", err } buf = buf[:0] argPos := 0 diff --git a/driver_test.go b/driver_test.go index 6ca5434a9..e2502d47e 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1991,3 +1991,31 @@ func TestRejectReadOnly(t *testing.T) { dbt.mustExec("DROP TABLE test") }) } + +func TestUnclosedRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + tx, err := dbt.db.Begin() + if err != nil { + dbt.Fatal(err) + } + + rows, err := tx.Query("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + // here's common mistake: rows are closed only + // when current func exits keeping the rows buffer + // busy for the following request. + defer rows.Close() + + if !rows.Next() { + dbt.Fatal("no rows after `SELECT 1`") + } + + _, err = tx.Query("SELECT 2") + if err != ErrUnreadTxRows { + dbt.Errorf("got %v, want %v", err, ErrUnreadTxRows) + } + }) +} diff --git a/errors.go b/errors.go index 857854e14..102e4b58d 100644 --- a/errors.go +++ b/errors.go @@ -30,7 +30,7 @@ var ( ErrPktSync = errors.New("commands out of sync. You can't run this command now") ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") - ErrBusyBuffer = errors.New("busy buffer") + ErrUnreadTxRows = errors.New("rows buffer is busy. Try to read out or close previous rows") ) var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) diff --git a/packets.go b/packets.go index 303405a17..440c15841 100644 --- a/packets.go +++ b/packets.go @@ -259,11 +259,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { + return err } // ClientFlags [32 bit] @@ -345,11 +343,9 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + pktLen) + if err != nil { + return err } // Add the scrambled password [null terminated string] @@ -364,11 +360,9 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { func (mc *mysqlConn) writeClearAuthPacket() error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + pktLen) + if err != nil { + return err } // Add the clear password [null terminated string] @@ -385,11 +379,9 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { // Calculate the packet length and add a tailing 0 pktLen := len(scrambleBuff) - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + pktLen) + if err != nil { + return err } // Add the scramble @@ -406,11 +398,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err } // Add command byte @@ -425,11 +415,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { + return err } // Add command byte @@ -446,11 +434,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { + return err } // Add command byte @@ -907,16 +893,15 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { mc.sequence = 0 var data []byte + var err error if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, err = mc.buf.takeBuffer(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, err = mc.buf.takeCompleteBuffer() } - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + if err != nil { + return err } // command [1 byte]