Skip to content

add support for specifying secondary indexes with to_sql #12904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,8 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs):
**kwargs)

def to_sql(self, name, con, flavor=None, schema=None, if_exists='fail',
index=True, index_label=None, chunksize=None, dtype=None):
index=True, index_label=None, chunksize=None, dtype=None,
indexes=None):
"""
Write records stored in a DataFrame to a SQL database.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are the params documented here?


Expand Down Expand Up @@ -1197,7 +1198,7 @@ def to_sql(self, name, con, flavor=None, schema=None, if_exists='fail',
from pandas.io import sql
sql.to_sql(self, name, con, flavor=flavor, schema=schema,
if_exists=if_exists, index=index, index_label=index_label,
chunksize=chunksize, dtype=dtype)
chunksize=chunksize, dtype=dtype, indexes=indexes)

def to_pickle(self, path):
"""
Expand Down
45 changes: 36 additions & 9 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,8 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None,


def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail',
index=True, index_label=None, chunksize=None, dtype=None):
index=True, index_label=None, indexes=None, chunksize=None,
dtype=None):
"""
Write records stored in a DataFrame to a SQL database.

Expand Down Expand Up @@ -445,6 +446,10 @@ def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail',
Column label for index column(s). If None is given (default) and
`index` is True, then the index names are used.
A sequence should be given if the DataFrame uses MultiIndex.
indexes : list of column name(s). Columns names in this list will have
an indexes created for them in the database.

.. versionadded:: 0.18.2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0.20.0

chunksize : int, default None
If not None, then rows will be written in batches of this size at a
time. If None, all rows will be written at once.
Expand All @@ -467,7 +472,7 @@ def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail',

pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index,
index_label=index_label, schema=schema,
chunksize=chunksize, dtype=dtype)
chunksize=chunksize, dtype=dtype, indexes=indexes)


def has_table(table_name, con, flavor=None, schema=None):
Expand Down Expand Up @@ -546,12 +551,13 @@ class SQLTable(PandasObject):

def __init__(self, name, pandas_sql_engine, frame=None, index=True,
if_exists='fail', prefix='pandas', index_label=None,
schema=None, keys=None, dtype=None):
schema=None, keys=None, dtype=None, indexes=None):
self.name = name
self.pd_sql = pandas_sql_engine
self.prefix = prefix
self.frame = frame
self.index = self._index_name(index, index_label)
self.indexes = indexes
self.schema = schema
self.if_exists = if_exists
self.keys = keys
Expand Down Expand Up @@ -742,18 +748,37 @@ def _index_name(self, index, index_label):
else:
return None

def _is_column_indexed(self, label):
# column is explicitly set to be indexed
if self.indexes is not None and label in self.indexes:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put a comment explaining this logic

return True

# if df index is also a column it needs an index unless it's
# also a primary key (otherwise there would be two indexes).
# multi-index can use primary key if the left hand side matches.
if self.index is not None and label in self.index:
if self.keys is None:
return True

col_nr = self.index.index(label) + 1
if self.keys[:col_nr] != self.index[:col_nr]:
return True

return False

def _get_column_names_and_types(self, dtype_mapper):
column_names_and_types = []
if self.index is not None:
for i, idx_label in enumerate(self.index):
idx_type = dtype_mapper(
self.frame.index.get_level_values(i))
column_names_and_types.append((idx_label, idx_type, True))
indexed = self._is_column_indexed(idx_label)
column_names_and_types.append((idx_label, idx_type, indexed))

column_names_and_types += [
(text_type(self.frame.columns[i]),
dtype_mapper(self.frame.iloc[:, i]),
False)
self._is_column_indexed(text_type(self.frame.columns[i])))
for i in range(len(self.frame.columns))
]

Expand Down Expand Up @@ -1098,7 +1123,8 @@ def read_query(self, sql, index_col=None, coerce_float=True,
read_sql = read_query

def to_sql(self, frame, name, if_exists='fail', index=True,
index_label=None, schema=None, chunksize=None, dtype=None):
index_label=None, schema=None, chunksize=None, dtype=None,
indexes=None):
"""
Write records stored in a DataFrame to a SQL database.

Expand Down Expand Up @@ -1142,7 +1168,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,

table = SQLTable(name, self, frame=frame, index=index,
if_exists=if_exists, index_label=index_label,
schema=schema, dtype=dtype)
schema=schema, dtype=dtype, indexes=indexes)
table.create()
table.insert(chunksize)
if (not name.isdigit() and not name.islower()):
Expand Down Expand Up @@ -1456,7 +1482,8 @@ def _fetchall_as_list(self, cur):
return result

def to_sql(self, frame, name, if_exists='fail', index=True,
index_label=None, schema=None, chunksize=None, dtype=None):
index_label=None, schema=None, chunksize=None, dtype=None,
indexes=None):
"""
Write records stored in a DataFrame to a SQL database.

Expand Down Expand Up @@ -1497,7 +1524,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True,

table = SQLiteTable(name, self, frame=frame, index=index,
if_exists=if_exists, index_label=index_label,
dtype=dtype)
dtype=dtype, indexes=indexes)
table.create()
table.insert(chunksize)

Expand Down
71 changes: 70 additions & 1 deletion pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,14 @@ def _load_test3_data(self):

self.test_frame3 = DataFrame(data, columns=columns)

def _load_test4_data(self):
n = 10
colors = np.random.choice(['red', 'green'], size=n)
foods = np.random.choice(['eggs', 'ham'], size=n)
index = pd.MultiIndex.from_arrays([colors, foods],
names=['color', 'food'])
self.test_frame4 = DataFrame(np.random.randn(n, 2), index=index)

def _load_raw_sql(self):
self.drop_table('types_test_data')
self._get_exec().execute(SQL_STRINGS['create_test_types'][self.flavor])
Expand Down Expand Up @@ -513,6 +521,7 @@ def setUp(self):
self._load_test1_data()
self._load_test2_data()
self._load_test3_data()
self._load_test4_data()
self._load_raw_sql()

def test_read_sql_iris(self):
Expand Down Expand Up @@ -930,7 +939,7 @@ def test_warning_case_insensitive_table_name(self):
def _get_index_columns(self, tbl_name):
from sqlalchemy.engine import reflection
insp = reflection.Inspector.from_engine(self.conn)
ixs = insp.get_indexes('test_index_saved')
ixs = insp.get_indexes(tbl_name)
ixs = [i['column_names'] for i in ixs]
return ixs

Expand Down Expand Up @@ -963,6 +972,66 @@ def test_to_sql_read_sql_with_database_uri(self):
tm.assert_frame_equal(test_frame1, test_frame3)
tm.assert_frame_equal(test_frame1, test_frame4)

def test_to_sql_column_indexes(self):
temp_frame = DataFrame({'col1': range(4), 'col2': range(4)})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add the issue number as a comment

sql.to_sql(temp_frame, 'test_to_sql_column_indexes', self.conn,
index=False, if_exists='replace', indexes=['col1', 'col2'])
ix_cols = self._get_index_columns('test_to_sql_column_indexes')
self.assertEqual(sorted(ix_cols), [['col1'], ['col2']],
"columns are not correctly indexes")

def test_sqltable_key_and_multiindex_no_pk(self):
db = sql.SQLDatabase(self.conn)
table = sql.SQLTable('test_sqltable_key_and_multiindex_no_pk', db,
frame=self.test_frame4, index=True)
metadata = table.table.tometadata(table.pd_sql.meta)
indexed_columns = [e.columns.keys() for e in metadata.indexes]
primary_keys = metadata.primary_key.columns.keys()
self.assertListEqual([['color'], ['food']], sorted(indexed_columns),
"Wrong secondary indexes")
self.assertListEqual([], primary_keys,
"There should be no primary keys")

def test_sqltable_key_and_multiindex_one_pk(self):
db = sql.SQLDatabase(self.conn)
table = sql.SQLTable('test_sqltable_key_and_multiindex_one_pk', db,
frame=self.test_frame4, index=True,
keys=['color'])
metadata = table.table.tometadata(table.pd_sql.meta)
indexed_columns = [e.columns.keys() for e in metadata.indexes]
primary_keys = metadata.primary_key.columns.keys()
self.assertListEqual([['food']], indexed_columns,
"Wrong secondary indexes")
self.assertListEqual(['color'], primary_keys,
"Wrong primary keys")

def test_sqltable_key_and_multiindex_two_pk(self):
db = sql.SQLDatabase(self.conn)
table = sql.SQLTable('test_sqltable_key_and_multiindex_two_pk', db,
frame=self.test_frame4, index=True,
keys=['color', 'food'])
metadata = table.table.tometadata(table.pd_sql.meta)
indexed_columns = [e.columns.keys() for e in metadata.indexes]
primary_keys = metadata.primary_key.columns.keys()
self.assertListEqual([], indexed_columns,
"There should be no secondary indexes")
self.assertListEqual(['color', 'food'], primary_keys,
"Wrong primary keys")

def test_sqltable_no_double_key_and_index_index(self):
temp_frame = DataFrame({'col1': range(4), 'col2': range(4)})
db = sql.SQLDatabase(self.conn)
table = sql.SQLTable('test_sqltable_no_double_key_and_index_index', db,
frame=temp_frame, index=True, index_label='id',
keys=['id'], indexes=['col1', 'col2'])
table_metadata = table.table.tometadata(table.pd_sql.meta)
indexed_columns = [e.columns.keys() for e in table_metadata.indexes]
self.assertNotIn('id', indexed_columns,
"Secondary Index found for primary key")

self.assertListEqual(['id'], table_metadata.primary_key.columns.keys(),
"Primary key missing from table")

def _make_iris_table_metadata(self):
sa = sqlalchemy
metadata = sa.MetaData()
Expand Down