19
19
from pandas .core .base import PandasObject
20
20
from pandas .tseries .tools import to_datetime
21
21
22
+ from contextlib import contextmanager
22
23
23
24
class SQLAlchemyRequired (ImportError ):
24
25
pass
@@ -637,13 +638,9 @@ def insert_data(self):
637
638
638
639
return column_names , data_list
639
640
640
- def get_session (self ):
641
- con = self .pd_sql .engine .connect ()
642
- return con .begin ()
643
-
644
- def _execute_insert (self , trans , keys , data_iter ):
641
+ def _execute_insert (self , conn , keys , data_iter ):
645
642
data = [dict ( (k , v ) for k , v in zip (keys , row ) ) for row in data_iter ]
646
- trans . connection .execute (self .insert_statement (), data )
643
+ conn .execute (self .insert_statement (), data )
647
644
648
645
def insert (self , chunksize = None ):
649
646
keys , data_list = self .insert_data ()
@@ -653,15 +650,15 @@ def insert(self, chunksize=None):
653
650
chunksize = nrows
654
651
chunks = int (nrows / chunksize ) + 1
655
652
656
- with self .get_session () as trans :
653
+ with self .pd_sql . run_transaction () as conn :
657
654
for i in range (chunks ):
658
655
start_i = i * chunksize
659
656
end_i = min ((i + 1 ) * chunksize , nrows )
660
657
if start_i >= end_i :
661
658
break
662
659
663
660
chunk_iter = zip (* [arr [start_i :end_i ] for arr in data_list ])
664
- self ._execute_insert (trans , keys , chunk_iter )
661
+ self ._execute_insert (conn , keys , chunk_iter )
665
662
666
663
def read (self , coerce_float = True , parse_dates = None , columns = None ):
667
664
@@ -884,6 +881,9 @@ def __init__(self, engine, schema=None, meta=None):
884
881
885
882
self .meta = meta
886
883
884
+ def run_transaction (self ):
885
+ return self .engine .begin ()
886
+
887
887
def execute (self , * args , ** kwargs ):
888
888
"""Simple passthrough to SQLAlchemy engine"""
889
889
return self .engine .execute (* args , ** kwargs )
@@ -1017,9 +1017,9 @@ def sql_schema(self):
1017
1017
return str (";\n " .join (self .table ))
1018
1018
1019
1019
def _execute_create (self ):
1020
- with self .get_session () :
1020
+ with self .pd_sql . run_transaction () as conn :
1021
1021
for stmt in self .table :
1022
- self . pd_sql .execute (stmt )
1022
+ conn .execute (stmt )
1023
1023
1024
1024
def insert_statement (self ):
1025
1025
names = list (map (str , self .frame .columns ))
@@ -1038,12 +1038,9 @@ def insert_statement(self):
1038
1038
self .name , col_names , wildcards )
1039
1039
return insert_statement
1040
1040
1041
- def get_session (self ):
1042
- return self .pd_sql .con
1043
-
1044
- def _execute_insert (self , trans , keys , data_iter ):
1041
+ def _execute_insert (self , conn , keys , data_iter ):
1045
1042
data_list = list (data_iter )
1046
- trans .executemany (self .insert_statement (), data_list )
1043
+ conn .executemany (self .insert_statement (), data_list )
1047
1044
1048
1045
def _create_table_setup (self ):
1049
1046
"""Return a list of SQL statement that create a table reflecting the
@@ -1125,6 +1122,17 @@ def __init__(self, con, flavor, is_cursor=False):
1125
1122
else :
1126
1123
self .flavor = flavor
1127
1124
1125
+ @contextmanager
1126
+ def run_transaction (self ):
1127
+ cur = self .con .cursor ()
1128
+ try :
1129
+ yield cur
1130
+ self .con .commit ()
1131
+ except :
1132
+ self .con .rollback ()
1133
+ finally :
1134
+ cur .close ()
1135
+
1128
1136
def execute (self , * args , ** kwargs ):
1129
1137
if self .is_cursor :
1130
1138
cur = self .con
0 commit comments