diff --git a/CHANGELOG.md b/CHANGELOG.md index af0c10d19..97cf052c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ New Features: - Logging of critical errors is configurable with `SetLogger` +Bugfixes: + + - Allow more than 32 parameters in prepared statements + + ## Version 1.1 (2013-11-02) Changes: diff --git a/driver_test.go b/driver_test.go index 57bac668b..cad40ec61 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1210,6 +1210,30 @@ func TestStmtMultiRows(t *testing.T) { }) } +// Regression test for +// * more than 32 NULL parameters (issue 209) +// * more parameters than fit into the buffer (issue 201) +func TestPreparedManyCols(t *testing.T) { + const numParams = defaultBufSize + runTests(t, dsn, func(dbt *DBTest) { + query := "SELECT ?" + strings.Repeat(",?", numParams-1) + stmt, err := dbt.db.Prepare(query) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + // create more parameters than fit into the buffer + // which will take nil-values + params := make([]interface{}, numParams) + rows, err := stmt.Query(params...) + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows.Close() + }) +} + func TestConcurrent(t *testing.T) { if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { t.Skip("MYSQL_TEST_CONCURRENT env var not set") diff --git a/packets.go b/packets.go index ff1a6eaba..49aaf1807 100644 --- a/packets.go +++ b/packets.go @@ -750,6 +750,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) } + const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc // Reset packet-sequence @@ -758,7 +759,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { var data []byte if len(args) == 0 { - data = mc.buf.takeBuffer(4 + 1 + 4 + 1 + 4) + data = mc.buf.takeBuffer(minPktLen) } else { data = mc.buf.takeCompleteBuffer() } @@ -787,10 +788,26 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data[13] = 0x00 if len(args) > 0 { - // NULL-bitmap [(len(args)+7)/8 bytes] - nullMask := uint64(0) - - pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3) + pos := minPktLen + + var nullMask []byte + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + // buffer has to be extended but we don't know by how much so + // we depend on append after all data with known sizes fit. + // We stop at that because we deal with a lot of columns here + // which makes the required allocation size hard to guess. + tmp := make([]byte, pos+maskLen+typesLen) + copy(tmp[:pos], data[:pos]) + data = tmp + nullMask = data[pos : pos+maskLen] + pos += maskLen + } else { + nullMask = data[pos : pos+maskLen] + for i := 0; i < maskLen; i++ { + nullMask[i] = 0 + } + pos += maskLen + } // newParameterBoundFlag 1 [1 byte] data[pos] = 0x01 @@ -798,23 +815,23 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // type of each parameter [len(args)*2 bytes] paramTypes := data[pos:] - pos += (len(args) << 1) + pos += len(args) * 2 // value of each parameter [n bytes] paramValues := data[pos:pos] valuesCap := cap(paramValues) - for i := range args { + for i, arg := range args { // build NULL-bitmap - if args[i] == nil { - nullMask |= 1 << uint(i) + if arg == nil { + nullMask[i/8] |= 1 << (uint(i) & 7) paramTypes[i+i] = fieldTypeNULL paramTypes[i+i+1] = 0x00 continue } // cache types and values - switch v := args[i].(type) { + switch v := arg.(type) { case int64: paramTypes[i+i] = fieldTypeLongLong paramTypes[i+i+1] = 0x00 @@ -877,7 +894,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } // Handle []byte(nil) as a NULL value - nullMask |= 1 << uint(i) + nullMask[i/8] |= 1 << (uint(i) & 7) paramTypes[i+i] = fieldTypeNULL paramTypes[i+i+1] = 0x00 @@ -913,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { paramValues = append(paramValues, val...) default: - return fmt.Errorf("Can't convert type: %T", args[i]) + return fmt.Errorf("Can't convert type: %T", arg) } } @@ -926,11 +943,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos += len(paramValues) data = data[:pos] - - // Convert nullMask to bytes - for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ { - data[i+14] = byte(nullMask >> uint(i<<3)) - } } return mc.writePacket(data)