diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 2f78c9acf7972..beb402f2eca98 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -1158,7 +1158,8 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs): **kwargs) def to_sql(self, name, con, flavor=None, schema=None, if_exists='fail', - index=True, index_label=None, chunksize=None, dtype=None): + index=True, index_label=None, chunksize=None, dtype=None, + indexes=None): """ Write records stored in a DataFrame to a SQL database. @@ -1197,7 +1198,7 @@ def to_sql(self, name, con, flavor=None, schema=None, if_exists='fail', from pandas.io import sql sql.to_sql(self, name, con, flavor=flavor, schema=schema, if_exists=if_exists, index=index, index_label=index_label, - chunksize=chunksize, dtype=dtype) + chunksize=chunksize, dtype=dtype, indexes=indexes) def to_pickle(self, path): """ diff --git a/pandas/io/sql.py b/pandas/io/sql.py index 47642c2e2bc28..994b9500ea1a9 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -416,7 +416,8 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None, def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail', - index=True, index_label=None, chunksize=None, dtype=None): + index=True, index_label=None, indexes=None, chunksize=None, + dtype=None): """ Write records stored in a DataFrame to a SQL database. @@ -445,6 +446,10 @@ def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail', Column label for index column(s). If None is given (default) and `index` is True, then the index names are used. A sequence should be given if the DataFrame uses MultiIndex. + indexes : list of column name(s). Column names in this list will each + have an index created for them in the database. + + .. versionadded:: 0.18.2 chunksize : int, default None If not None, then rows will be written in batches of this size at a time. If None, all rows will be written at once. 
@@ -467,7 +472,7 @@ def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail', pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index, index_label=index_label, schema=schema, - chunksize=chunksize, dtype=dtype) + chunksize=chunksize, dtype=dtype, indexes=indexes) def has_table(table_name, con, flavor=None, schema=None): @@ -546,12 +551,13 @@ class SQLTable(PandasObject): def __init__(self, name, pandas_sql_engine, frame=None, index=True, if_exists='fail', prefix='pandas', index_label=None, - schema=None, keys=None, dtype=None): + schema=None, keys=None, dtype=None, indexes=None): self.name = name self.pd_sql = pandas_sql_engine self.prefix = prefix self.frame = frame self.index = self._index_name(index, index_label) + self.indexes = indexes self.schema = schema self.if_exists = if_exists self.keys = keys @@ -742,18 +748,37 @@ def _index_name(self, index, index_label): else: return None + def _is_column_indexed(self, label): + # column is explicitly set to be indexed + if self.indexes is not None and label in self.indexes: + return True + + # if df index is also a column it needs an index unless it's + # also a primary key (otherwise there would be two indexes). + # multi-index can use primary key if the left hand side matches. 
+ if self.index is not None and label in self.index: + if self.keys is None: + return True + + col_nr = self.index.index(label) + 1 + if self.keys[:col_nr] != self.index[:col_nr]: + return True + + return False + def _get_column_names_and_types(self, dtype_mapper): column_names_and_types = [] if self.index is not None: for i, idx_label in enumerate(self.index): idx_type = dtype_mapper( self.frame.index.get_level_values(i)) - column_names_and_types.append((idx_label, idx_type, True)) + indexed = self._is_column_indexed(idx_label) + column_names_and_types.append((idx_label, idx_type, indexed)) column_names_and_types += [ (text_type(self.frame.columns[i]), dtype_mapper(self.frame.iloc[:, i]), - False) + self._is_column_indexed(text_type(self.frame.columns[i]))) for i in range(len(self.frame.columns)) ] @@ -1098,7 +1123,8 @@ def read_query(self, sql, index_col=None, coerce_float=True, read_sql = read_query def to_sql(self, frame, name, if_exists='fail', index=True, - index_label=None, schema=None, chunksize=None, dtype=None): + index_label=None, schema=None, chunksize=None, dtype=None, + indexes=None): """ Write records stored in a DataFrame to a SQL database. @@ -1142,7 +1168,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, table = SQLTable(name, self, frame=frame, index=index, if_exists=if_exists, index_label=index_label, - schema=schema, dtype=dtype) + schema=schema, dtype=dtype, indexes=indexes) table.create() table.insert(chunksize) if (not name.isdigit() and not name.islower()): @@ -1456,7 +1482,8 @@ def _fetchall_as_list(self, cur): return result def to_sql(self, frame, name, if_exists='fail', index=True, - index_label=None, schema=None, chunksize=None, dtype=None): + index_label=None, schema=None, chunksize=None, dtype=None, + indexes=None): """ Write records stored in a DataFrame to a SQL database. 
@@ -1497,7 +1524,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, table = SQLiteTable(name, self, frame=frame, index=index, if_exists=if_exists, index_label=index_label, - dtype=dtype) + dtype=dtype, indexes=indexes) table.create() table.insert(chunksize) diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index 198a4017b5af7..58aa74e219bb5 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -310,6 +310,14 @@ def _load_test3_data(self): self.test_frame3 = DataFrame(data, columns=columns) + def _load_test4_data(self): + n = 10 + colors = np.random.choice(['red', 'green'], size=n) + foods = np.random.choice(['eggs', 'ham'], size=n) + index = pd.MultiIndex.from_arrays([colors, foods], + names=['color', 'food']) + self.test_frame4 = DataFrame(np.random.randn(n, 2), index=index) + def _load_raw_sql(self): self.drop_table('types_test_data') self._get_exec().execute(SQL_STRINGS['create_test_types'][self.flavor]) @@ -513,6 +521,7 @@ def setUp(self): self._load_test1_data() self._load_test2_data() self._load_test3_data() + self._load_test4_data() self._load_raw_sql() def test_read_sql_iris(self): @@ -930,7 +939,7 @@ def test_warning_case_insensitive_table_name(self): def _get_index_columns(self, tbl_name): from sqlalchemy.engine import reflection insp = reflection.Inspector.from_engine(self.conn) - ixs = insp.get_indexes('test_index_saved') + ixs = insp.get_indexes(tbl_name) ixs = [i['column_names'] for i in ixs] return ixs @@ -963,6 +972,66 @@ def test_to_sql_read_sql_with_database_uri(self): tm.assert_frame_equal(test_frame1, test_frame3) tm.assert_frame_equal(test_frame1, test_frame4) + def test_to_sql_column_indexes(self): + temp_frame = DataFrame({'col1': range(4), 'col2': range(4)}) + sql.to_sql(temp_frame, 'test_to_sql_column_indexes', self.conn, + index=False, if_exists='replace', indexes=['col1', 'col2']) + ix_cols = self._get_index_columns('test_to_sql_column_indexes') + self.assertEqual(sorted(ix_cols), 
[['col1'], ['col2']], + "columns are not correctly indexes") + + def test_sqltable_key_and_multiindex_no_pk(self): + db = sql.SQLDatabase(self.conn) + table = sql.SQLTable('test_sqltable_key_and_multiindex_no_pk', db, + frame=self.test_frame4, index=True) + metadata = table.table.tometadata(table.pd_sql.meta) + indexed_columns = [e.columns.keys() for e in metadata.indexes] + primary_keys = metadata.primary_key.columns.keys() + self.assertListEqual([['color'], ['food']], sorted(indexed_columns), + "Wrong secondary indexes") + self.assertListEqual([], primary_keys, + "There should be no primary keys") + + def test_sqltable_key_and_multiindex_one_pk(self): + db = sql.SQLDatabase(self.conn) + table = sql.SQLTable('test_sqltable_key_and_multiindex_one_pk', db, + frame=self.test_frame4, index=True, + keys=['color']) + metadata = table.table.tometadata(table.pd_sql.meta) + indexed_columns = [e.columns.keys() for e in metadata.indexes] + primary_keys = metadata.primary_key.columns.keys() + self.assertListEqual([['food']], indexed_columns, + "Wrong secondary indexes") + self.assertListEqual(['color'], primary_keys, + "Wrong primary keys") + + def test_sqltable_key_and_multiindex_two_pk(self): + db = sql.SQLDatabase(self.conn) + table = sql.SQLTable('test_sqltable_key_and_multiindex_two_pk', db, + frame=self.test_frame4, index=True, + keys=['color', 'food']) + metadata = table.table.tometadata(table.pd_sql.meta) + indexed_columns = [e.columns.keys() for e in metadata.indexes] + primary_keys = metadata.primary_key.columns.keys() + self.assertListEqual([], indexed_columns, + "There should be no secondary indexes") + self.assertListEqual(['color', 'food'], primary_keys, + "Wrong primary keys") + + def test_sqltable_no_double_key_and_index_index(self): + temp_frame = DataFrame({'col1': range(4), 'col2': range(4)}) + db = sql.SQLDatabase(self.conn) + table = sql.SQLTable('test_sqltable_no_double_key_and_index_index', db, + frame=temp_frame, index=True, index_label='id', + 
keys=['id'], indexes=['col1', 'col2']) + table_metadata = table.table.tometadata(table.pd_sql.meta) + indexed_columns = [e.columns.keys() for e in table_metadata.indexes] + self.assertNotIn('id', indexed_columns, + "Secondary Index found for primary key") + + self.assertListEqual(['id'], table_metadata.primary_key.columns.keys(), + "Primary key missing from table") + def _make_iris_table_metadata(self): sa = sqlalchemy metadata = sa.MetaData()