Skip to content

return sql.NullTime if it available #1145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 108 additions & 114 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2758,13 +2758,13 @@ func TestRowsColumnTypes(t *testing.T) {
nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false}
nf0 := sql.NullFloat64{Float64: 0.0, Valid: true}
nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true}
nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
ndNULL := NullTime{Time: time.Time{}, Valid: false}
nt0 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
nt1 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
nt2 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
nt6 := nullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
nd1 := nullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
nd2 := nullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
ndNULL := nullTime{Time: time.Time{}, Valid: false}
rbNULL := sql.RawBytes(nil)
rb0 := sql.RawBytes("0")
rb42 := sql.RawBytes("42")
Expand Down Expand Up @@ -2844,131 +2844,125 @@ func TestRowsColumnTypes(t *testing.T) {
values2 = values2[:len(values2)-2]
values3 = values3[:len(values3)-2]

dsns := []string{
dsn + "&parseTime=true",
dsn + "&parseTime=false",
}
for _, testdsn := range dsns {
runTests(t, testdsn, func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (" + schema + ")")
dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")
runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (" + schema + ")")
dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")

rows, err := dbt.db.Query("SELECT * FROM test")
if err != nil {
t.Fatalf("Query: %v", err)
}
rows, err := dbt.db.Query("SELECT * FROM test")
if err != nil {
t.Fatalf("Query: %v", err)
}

tt, err := rows.ColumnTypes()
if err != nil {
t.Fatalf("ColumnTypes: %v", err)
}
tt, err := rows.ColumnTypes()
if err != nil {
t.Fatalf("ColumnTypes: %v", err)
}

if len(tt) != len(columns) {
t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
}
if len(tt) != len(columns) {
t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
}

types := make([]reflect.Type, len(tt))
for i, tp := range tt {
column := columns[i]
types := make([]reflect.Type, len(tt))
for i, tp := range tt {
column := columns[i]

// Name
name := tp.Name()
if name != column.name {
t.Errorf("column name mismatch %s != %s", name, column.name)
continue
}
// Name
name := tp.Name()
if name != column.name {
t.Errorf("column name mismatch %s != %s", name, column.name)
continue
}

// DatabaseTypeName
databaseTypeName := tp.DatabaseTypeName()
if databaseTypeName != column.databaseTypeName {
t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
continue
}
// DatabaseTypeName
databaseTypeName := tp.DatabaseTypeName()
if databaseTypeName != column.databaseTypeName {
t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
continue
}

// ScanType
scanType := tp.ScanType()
if scanType != column.scanType {
if scanType == nil {
t.Errorf("scantype is null for column %q", name)
} else {
t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
}
continue
// ScanType
scanType := tp.ScanType()
if scanType != column.scanType {
if scanType == nil {
t.Errorf("scantype is null for column %q", name)
} else {
t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
}
types[i] = scanType

// Nullable
nullable, ok := tp.Nullable()
continue
}
types[i] = scanType

// Nullable
nullable, ok := tp.Nullable()
if !ok {
t.Errorf("nullable not ok %q", name)
continue
}
if nullable != column.nullable {
t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
}

// Length
// length, ok := tp.Length()
// if length != column.length {
// if !ok {
// t.Errorf("length not ok for column %q", name)
// } else {
// t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
// }
// continue
// }

// Precision and Scale
precision, scale, ok := tp.DecimalSize()
if precision != column.precision {
if !ok {
t.Errorf("nullable not ok %q", name)
continue
}
if nullable != column.nullable {
t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
t.Errorf("precision not ok for column %q", name)
} else {
t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
}

// Length
// length, ok := tp.Length()
// if length != column.length {
// if !ok {
// t.Errorf("length not ok for column %q", name)
// } else {
// t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
// }
// continue
// }

// Precision and Scale
precision, scale, ok := tp.DecimalSize()
if precision != column.precision {
if !ok {
t.Errorf("precision not ok for column %q", name)
} else {
t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
}
continue
}
if scale != column.scale {
if !ok {
t.Errorf("scale not ok for column %q", name)
} else {
t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
}
continue
continue
}
if scale != column.scale {
if !ok {
t.Errorf("scale not ok for column %q", name)
} else {
t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
}
continue
}
}

values := make([]interface{}, len(tt))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
values := make([]interface{}, len(tt))
for i := range values {
values[i] = reflect.New(types[i]).Interface()
}
i := 0
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
t.Fatalf("failed to scan values in %v", err)
}
i := 0
for rows.Next() {
err = rows.Scan(values...)
if err != nil {
t.Fatalf("failed to scan values in %v", err)
}
for j := range values {
value := reflect.ValueOf(values[j]).Elem().Interface()
if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
if columns[j].scanType == scanTypeRawBytes {
t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
} else {
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
}
for j := range values {
value := reflect.ValueOf(values[j]).Elem().Interface()
if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
if columns[j].scanType == scanTypeRawBytes {
t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
} else {
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
}
}
i++
}
if i != 3 {
t.Errorf("expected 3 rows, got %d", i)
}
i++
}
if i != 3 {
t.Errorf("expected 3 rows, got %d", i)
}

if err := rows.Close(); err != nil {
t.Errorf("error closing rows: %s", err)
}
})
}
if err := rows.Close(); err != nil {
t.Errorf("error closing rows: %s", err)
}
})
}

func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ var (
scanTypeInt64 = reflect.TypeOf(int64(0))
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
scanTypeNullTime = reflect.TypeOf(NullTime{})
scanTypeNullTime = reflect.TypeOf(nullTime{})
scanTypeUint8 = reflect.TypeOf(uint8(0))
scanTypeUint16 = reflect.TypeOf(uint16(0))
scanTypeUint32 = reflect.TypeOf(uint32(0))
Expand Down
5 changes: 5 additions & 0 deletions nulltime_go113.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ import (
//
// This NullTime implementation is not driver-specific
type NullTime sql.NullTime

// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = sql.NullTime // sql.NullTime is available
5 changes: 5 additions & 0 deletions nulltime_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}

// for internal use.
// the mysql package uses sql.NullTime if it is available.
// if not, the package uses mysql.NullTime.
type nullTime = NullTime // sql.NullTime is not available