-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
add support for specifying secondary indexes with to_sql #12904
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -416,7 +416,8 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None, | |
|
||
|
||
def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail', | ||
index=True, index_label=None, chunksize=None, dtype=None): | ||
index=True, index_label=None, indexes=None, chunksize=None, | ||
dtype=None): | ||
""" | ||
Write records stored in a DataFrame to a SQL database. | ||
|
||
|
@@ -445,6 +446,10 @@ def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail', | |
Column label for index column(s). If None is given (default) and | ||
`index` is True, then the index names are used. | ||
A sequence should be given if the DataFrame uses MultiIndex. | ||
indexes : list of column name(s). Column names in this list will each have | ||
an index created for them in the database. | ||
|
||
.. versionadded:: 0.18.2 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 0.20.0 |
||
chunksize : int, default None | ||
If not None, then rows will be written in batches of this size at a | ||
time. If None, all rows will be written at once. | ||
|
@@ -467,7 +472,7 @@ def to_sql(frame, name, con, flavor=None, schema=None, if_exists='fail', | |
|
||
pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index, | ||
index_label=index_label, schema=schema, | ||
chunksize=chunksize, dtype=dtype) | ||
chunksize=chunksize, dtype=dtype, indexes=indexes) | ||
|
||
|
||
def has_table(table_name, con, flavor=None, schema=None): | ||
|
@@ -546,12 +551,13 @@ class SQLTable(PandasObject): | |
|
||
def __init__(self, name, pandas_sql_engine, frame=None, index=True, | ||
if_exists='fail', prefix='pandas', index_label=None, | ||
schema=None, keys=None, dtype=None): | ||
schema=None, keys=None, dtype=None, indexes=None): | ||
self.name = name | ||
self.pd_sql = pandas_sql_engine | ||
self.prefix = prefix | ||
self.frame = frame | ||
self.index = self._index_name(index, index_label) | ||
self.indexes = indexes | ||
self.schema = schema | ||
self.if_exists = if_exists | ||
self.keys = keys | ||
|
@@ -742,18 +748,37 @@ def _index_name(self, index, index_label): | |
else: | ||
return None | ||
|
||
def _is_column_indexed(self, label):
    """
    Decide whether the column ``label`` needs its own secondary index.

    A column gets an index when either

    * it is listed explicitly in ``self.indexes``, or
    * it is one of the DataFrame index levels in ``self.index`` and is
      not already covered by the primary key in ``self.keys``.

    An index level counts as covered when the primary key starts with
    the same columns as the index, up to and including that level.
    A second index on such a column would only duplicate the primary
    key.

    Parameters
    ----------
    label : column name to check

    Returns
    -------
    bool
        True if a secondary index should be created for ``label``.
    """
    # column is explicitly set to be indexed
    if self.indexes is not None and label in self.indexes:
        return True

    # only DataFrame index levels are indexed implicitly
    if self.index is None or label not in self.index:
        return False

    # no primary key at all: every index level needs its own index
    if self.keys is None:
        return True

    # Compare list against list: a tuple of keys (or a single key name)
    # must not look different from an equal list of index labels,
    # otherwise a redundant index would be requested for a column the
    # primary key already covers.
    # NOTE(review): treating a bare string as a one-column key assumes
    # callers may pass a scalar ``keys`` -- confirm against the caller.
    if isinstance(self.keys, str):
        key_columns = [self.keys]
    else:
        key_columns = list(self.keys)
    index_columns = list(self.index)

    # multi-index can use the primary key if the left hand side matches
    col_nr = index_columns.index(label) + 1
    return key_columns[:col_nr] != index_columns[:col_nr]
|
||
def _get_column_names_and_types(self, dtype_mapper): | ||
column_names_and_types = [] | ||
if self.index is not None: | ||
for i, idx_label in enumerate(self.index): | ||
idx_type = dtype_mapper( | ||
self.frame.index.get_level_values(i)) | ||
column_names_and_types.append((idx_label, idx_type, True)) | ||
indexed = self._is_column_indexed(idx_label) | ||
column_names_and_types.append((idx_label, idx_type, indexed)) | ||
|
||
column_names_and_types += [ | ||
(text_type(self.frame.columns[i]), | ||
dtype_mapper(self.frame.iloc[:, i]), | ||
False) | ||
self._is_column_indexed(text_type(self.frame.columns[i]))) | ||
for i in range(len(self.frame.columns)) | ||
] | ||
|
||
|
@@ -1098,7 +1123,8 @@ def read_query(self, sql, index_col=None, coerce_float=True, | |
read_sql = read_query | ||
|
||
def to_sql(self, frame, name, if_exists='fail', index=True, | ||
index_label=None, schema=None, chunksize=None, dtype=None): | ||
index_label=None, schema=None, chunksize=None, dtype=None, | ||
indexes=None): | ||
""" | ||
Write records stored in a DataFrame to a SQL database. | ||
|
||
|
@@ -1142,7 +1168,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, | |
|
||
table = SQLTable(name, self, frame=frame, index=index, | ||
if_exists=if_exists, index_label=index_label, | ||
schema=schema, dtype=dtype) | ||
schema=schema, dtype=dtype, indexes=indexes) | ||
table.create() | ||
table.insert(chunksize) | ||
if (not name.isdigit() and not name.islower()): | ||
|
@@ -1456,7 +1482,8 @@ def _fetchall_as_list(self, cur): | |
return result | ||
|
||
def to_sql(self, frame, name, if_exists='fail', index=True, | ||
index_label=None, schema=None, chunksize=None, dtype=None): | ||
index_label=None, schema=None, chunksize=None, dtype=None, | ||
indexes=None): | ||
""" | ||
Write records stored in a DataFrame to a SQL database. | ||
|
||
|
@@ -1497,7 +1524,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, | |
|
||
table = SQLiteTable(name, self, frame=frame, index=index, | ||
if_exists=if_exists, index_label=index_label, | ||
dtype=dtype) | ||
dtype=dtype, indexes=indexes) | ||
table.create() | ||
table.insert(chunksize) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -310,6 +310,14 @@ def _load_test3_data(self): | |
|
||
self.test_frame3 = DataFrame(data, columns=columns) | ||
|
||
def _load_test4_data(self):
    """
    Build ``self.test_frame4``: a 10 x 2 frame of random normal values
    on a two-level MultiIndex named ('color', 'food').
    """
    row_count = 10
    # draw the level values in the order color, food
    level_values = [
        np.random.choice(choices, size=row_count)
        for choices in (['red', 'green'], ['eggs', 'ham'])
    ]
    multi_index = pd.MultiIndex.from_arrays(level_values,
                                            names=['color', 'food'])
    values = np.random.randn(row_count, 2)
    self.test_frame4 = DataFrame(values, index=multi_index)
|
||
def _load_raw_sql(self): | ||
self.drop_table('types_test_data') | ||
self._get_exec().execute(SQL_STRINGS['create_test_types'][self.flavor]) | ||
|
@@ -513,6 +521,7 @@ def setUp(self): | |
self._load_test1_data() | ||
self._load_test2_data() | ||
self._load_test3_data() | ||
self._load_test4_data() | ||
self._load_raw_sql() | ||
|
||
def test_read_sql_iris(self): | ||
|
@@ -930,7 +939,7 @@ def test_warning_case_insensitive_table_name(self): | |
def _get_index_columns(self, tbl_name):
    """
    Return the columns of every index defined on ``tbl_name``.

    The result holds one entry per index, each entry being that index's
    ``column_names`` as reported by the SQLAlchemy inspector.
    """
    from sqlalchemy.engine import reflection
    insp = reflection.Inspector.from_engine(self.conn)
    # Inspect the requested table only; a lookup of the hard-coded
    # 'test_index_saved' table would ignore ``tbl_name``.
    return [ix['column_names'] for ix in insp.get_indexes(tbl_name)]
|
||
|
@@ -963,6 +972,66 @@ def test_to_sql_read_sql_with_database_uri(self): | |
tm.assert_frame_equal(test_frame1, test_frame3) | ||
tm.assert_frame_equal(test_frame1, test_frame4) | ||
|
||
def test_to_sql_column_indexes(self):
    # GH 12904: every column named in ``indexes`` must end up with its
    # own single-column index in the database.
    table_name = 'test_to_sql_column_indexes'
    temp_frame = DataFrame({'col1': range(4), 'col2': range(4)})
    sql.to_sql(temp_frame, table_name, self.conn, index=False,
               if_exists='replace', indexes=['col1', 'col2'])
    found = sorted(self._get_index_columns(table_name))
    self.assertEqual(found, [['col1'], ['col2']],
                     "columns are not correctly indexes")
|
||
def test_sqltable_key_and_multiindex_no_pk(self):
    # Without a primary key, each MultiIndex level gets its own
    # secondary index and no primary key column is defined.
    pandas_db = sql.SQLDatabase(self.conn)
    sql_table = sql.SQLTable('test_sqltable_key_and_multiindex_no_pk',
                             pandas_db, frame=self.test_frame4,
                             index=True)
    meta_table = sql_table.table.tometadata(sql_table.pd_sql.meta)
    secondary = [ix.columns.keys() for ix in meta_table.indexes]
    pk_columns = meta_table.primary_key.columns.keys()
    self.assertListEqual([['color'], ['food']], sorted(secondary),
                         "Wrong secondary indexes")
    self.assertListEqual([], pk_columns,
                         "There should be no primary keys")
|
||
def test_sqltable_key_and_multiindex_one_pk(self):
    # With only the first index level as primary key, the second level
    # is expected to keep its own secondary index.
    pandas_db = sql.SQLDatabase(self.conn)
    sql_table = sql.SQLTable('test_sqltable_key_and_multiindex_one_pk',
                             pandas_db, frame=self.test_frame4,
                             index=True, keys=['color'])
    meta_table = sql_table.table.tometadata(sql_table.pd_sql.meta)
    secondary = [ix.columns.keys() for ix in meta_table.indexes]
    pk_columns = meta_table.primary_key.columns.keys()
    self.assertListEqual([['food']], secondary,
                         "Wrong secondary indexes")
    self.assertListEqual(['color'], pk_columns,
                         "Wrong primary keys")
|
||
def test_sqltable_key_and_multiindex_two_pk(self):
    # When the primary key spans both index levels, no secondary index
    # is expected at all.
    pandas_db = sql.SQLDatabase(self.conn)
    sql_table = sql.SQLTable('test_sqltable_key_and_multiindex_two_pk',
                             pandas_db, frame=self.test_frame4,
                             index=True, keys=['color', 'food'])
    meta_table = sql_table.table.tometadata(sql_table.pd_sql.meta)
    secondary = [ix.columns.keys() for ix in meta_table.indexes]
    pk_columns = meta_table.primary_key.columns.keys()
    self.assertListEqual([], secondary,
                         "There should be no secondary indexes")
    self.assertListEqual(['color', 'food'], pk_columns,
                         "Wrong primary keys")
|
||
def test_sqltable_no_double_key_and_index_index(self):
    # A DataFrame index that is also the primary key must not receive
    # an additional secondary index, even when other columns are
    # indexed explicitly through ``indexes``.
    temp_frame = DataFrame({'col1': range(4), 'col2': range(4)})
    db = sql.SQLDatabase(self.conn)
    table = sql.SQLTable('test_sqltable_no_double_key_and_index_index', db,
                         frame=temp_frame, index=True, index_label='id',
                         keys=['id'], indexes=['col1', 'col2'])
    table_metadata = table.table.tometadata(table.pd_sql.meta)
    indexed_columns = [e.columns.keys() for e in table_metadata.indexes]

    # ``indexed_columns`` holds one list of column names per index, so
    # flatten it first: testing the bare string against the list of
    # lists could never fail.
    flat_indexed = [col for cols in indexed_columns for col in cols]
    self.assertNotIn('id', flat_indexed,
                     "Secondary Index found for primary key")

    self.assertListEqual(['id'], table_metadata.primary_key.columns.keys(),
                         "Primary key missing from table")
|
||
def _make_iris_table_metadata(self): | ||
sa = sqlalchemy | ||
metadata = sa.MetaData() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are the params documented here?