@@ -774,7 +774,12 @@ class MPLPlot(object):
774
774
data :
775
775
776
776
"""
777
- _kind = 'base'
777
+
778
+ @property
779
+ def _kind (self ):
780
+ """Specify kind str. Must be overridden in child class"""
781
+ raise NotImplementedError
782
+
778
783
_layout_type = 'vertical'
779
784
_default_rot = 0
780
785
orientation = None
@@ -938,7 +943,10 @@ def generate(self):
938
943
self ._make_plot ()
939
944
self ._add_table ()
940
945
self ._make_legend ()
941
- self ._post_plot_logic ()
946
+
947
+ for ax in self .axes :
948
+ self ._post_plot_logic_common (ax , self .data )
949
+ self ._post_plot_logic (ax , self .data )
942
950
self ._adorn_subplots ()
943
951
944
952
def _args_adjust (self ):
@@ -1055,12 +1063,34 @@ def _add_table(self):
1055
1063
ax = self ._get_ax (0 )
1056
1064
table (ax , data )
1057
1065
1058
- def _post_plot_logic (self ):
1066
+ def _post_plot_logic_common (self , ax , data ):
1067
+ """Common post process for each axes"""
1068
+ labels = [com .pprint_thing (key ) for key in data .index ]
1069
+ labels = dict (zip (range (len (data .index )), labels ))
1070
+
1071
+ if self .orientation == 'vertical' or self .orientation is None :
1072
+ if self ._need_to_set_index :
1073
+ xticklabels = [labels .get (x , '' ) for x in ax .get_xticks ()]
1074
+ ax .set_xticklabels (xticklabels )
1075
+ self ._apply_axis_properties (ax .xaxis , rot = self .rot ,
1076
+ fontsize = self .fontsize )
1077
+ self ._apply_axis_properties (ax .yaxis , fontsize = self .fontsize )
1078
+ elif self .orientation == 'horizontal' :
1079
+ if self ._need_to_set_index :
1080
+ yticklabels = [labels .get (y , '' ) for y in ax .get_yticks ()]
1081
+ ax .set_yticklabels (yticklabels )
1082
+ self ._apply_axis_properties (ax .yaxis , rot = self .rot ,
1083
+ fontsize = self .fontsize )
1084
+ self ._apply_axis_properties (ax .xaxis , fontsize = self .fontsize )
1085
+ else : # pragma no cover
1086
+ raise ValueError
1087
+
1088
+ def _post_plot_logic (self , ax , data ):
1089
+ """Post process for each axes. Overridden in child classes"""
1059
1090
pass
1060
1091
1061
1092
def _adorn_subplots (self ):
1062
- to_adorn = self .axes
1063
-
1093
+ """Common post process unrelated to data"""
1064
1094
if len (self .axes ) > 0 :
1065
1095
all_axes = self ._get_axes ()
1066
1096
nrows , ncols = self ._get_axes_layout ()
@@ -1069,7 +1099,7 @@ def _adorn_subplots(self):
1069
1099
ncols = ncols , sharex = self .sharex ,
1070
1100
sharey = self .sharey )
1071
1101
1072
- for ax in to_adorn :
1102
+ for ax in self . axes :
1073
1103
if self .yticks is not None :
1074
1104
ax .set_yticks (self .yticks )
1075
1105
@@ -1090,25 +1120,6 @@ def _adorn_subplots(self):
1090
1120
else :
1091
1121
self .axes [0 ].set_title (self .title )
1092
1122
1093
- labels = [com .pprint_thing (key ) for key in self .data .index ]
1094
- labels = dict (zip (range (len (self .data .index )), labels ))
1095
-
1096
- for ax in self .axes :
1097
- if self .orientation == 'vertical' or self .orientation is None :
1098
- if self ._need_to_set_index :
1099
- xticklabels = [labels .get (x , '' ) for x in ax .get_xticks ()]
1100
- ax .set_xticklabels (xticklabels )
1101
- self ._apply_axis_properties (ax .xaxis , rot = self .rot ,
1102
- fontsize = self .fontsize )
1103
- self ._apply_axis_properties (ax .yaxis , fontsize = self .fontsize )
1104
- elif self .orientation == 'horizontal' :
1105
- if self ._need_to_set_index :
1106
- yticklabels = [labels .get (y , '' ) for y in ax .get_yticks ()]
1107
- ax .set_yticklabels (yticklabels )
1108
- self ._apply_axis_properties (ax .yaxis , rot = self .rot ,
1109
- fontsize = self .fontsize )
1110
- self ._apply_axis_properties (ax .xaxis , fontsize = self .fontsize )
1111
-
1112
1123
def _apply_axis_properties (self , axis , rot = None , fontsize = None ):
1113
1124
labels = axis .get_majorticklabels () + axis .get_minorticklabels ()
1114
1125
for label in labels :
@@ -1419,34 +1430,48 @@ def _get_axes_layout(self):
1419
1430
y_set .add (points [0 ][1 ])
1420
1431
return (len (y_set ), len (x_set ))
1421
1432
1422
- class ScatterPlot (MPLPlot ):
1423
- _kind = 'scatter'
1433
+
1434
+ class PlanePlot (MPLPlot ):
1435
+ """
1436
+ Abstract class for plotting on plane, currently scatter and hexbin.
1437
+ """
1438
+
1424
1439
_layout_type = 'single'
1425
1440
1426
- def __init__ (self , data , x , y , c = None , ** kwargs ):
1441
+ def __init__ (self , data , x , y , ** kwargs ):
1427
1442
MPLPlot .__init__ (self , data , ** kwargs )
1428
1443
if x is None or y is None :
1429
- raise ValueError ( 'scatter requires and x and y column' )
1444
+ raise ValueError (self . _kind + ' requires and x and y column' )
1430
1445
if com .is_integer (x ) and not self .data .columns .holds_integer ():
1431
1446
x = self .data .columns [x ]
1432
1447
if com .is_integer (y ) and not self .data .columns .holds_integer ():
1433
1448
y = self .data .columns [y ]
1434
- if com .is_integer (c ) and not self .data .columns .holds_integer ():
1435
- c = self .data .columns [c ]
1436
1449
self .x = x
1437
1450
self .y = y
1438
- self .c = c
1439
1451
1440
1452
@property
1441
1453
def nseries (self ):
1442
1454
return 1
1443
1455
1456
+ def _post_plot_logic (self , ax , data ):
1457
+ x , y = self .x , self .y
1458
+ ax .set_ylabel (com .pprint_thing (y ))
1459
+ ax .set_xlabel (com .pprint_thing (x ))
1460
+
1461
+
1462
+ class ScatterPlot (PlanePlot ):
1463
+ _kind = 'scatter'
1464
+
1465
+ def __init__ (self , data , x , y , c = None , ** kwargs ):
1466
+ super (ScatterPlot , self ).__init__ (data , x , y , ** kwargs )
1467
+ if com .is_integer (c ) and not self .data .columns .holds_integer ():
1468
+ c = self .data .columns [c ]
1469
+ self .c = c
1470
+
1444
1471
def _make_plot (self ):
1445
1472
import matplotlib as mpl
1446
1473
mpl_ge_1_3_1 = str (mpl .__version__ ) >= LooseVersion ('1.3.1' )
1447
1474
1448
- import matplotlib .pyplot as plt
1449
-
1450
1475
x , y , c , data = self .x , self .y , self .c , self .data
1451
1476
ax = self .axes [0 ]
1452
1477
@@ -1457,7 +1482,7 @@ def _make_plot(self):
1457
1482
1458
1483
# pandas uses colormap, matplotlib uses cmap.
1459
1484
cmap = self .colormap or 'Greys'
1460
- cmap = plt .cm .get_cmap (cmap )
1485
+ cmap = self . plt .cm .get_cmap (cmap )
1461
1486
1462
1487
if c is None :
1463
1488
c_values = self .plt .rcParams ['patch.facecolor' ]
@@ -1491,46 +1516,22 @@ def _make_plot(self):
1491
1516
err_kwds ['ecolor' ] = scatter .get_facecolor ()[0 ]
1492
1517
ax .errorbar (data [x ].values , data [y ].values , linestyle = 'none' , ** err_kwds )
1493
1518
1494
- def _post_plot_logic (self ):
1495
- ax = self .axes [0 ]
1496
- x , y = self .x , self .y
1497
- ax .set_ylabel (com .pprint_thing (y ))
1498
- ax .set_xlabel (com .pprint_thing (x ))
1499
-
1500
1519
1501
- class HexBinPlot (MPLPlot ):
1520
+ class HexBinPlot (PlanePlot ):
1502
1521
_kind = 'hexbin'
1503
- _layout_type = 'single'
1504
1522
1505
1523
def __init__ (self , data , x , y , C = None , ** kwargs ):
1506
- MPLPlot .__init__ (self , data , ** kwargs )
1507
-
1508
- if x is None or y is None :
1509
- raise ValueError ('hexbin requires and x and y column' )
1510
- if com .is_integer (x ) and not self .data .columns .holds_integer ():
1511
- x = self .data .columns [x ]
1512
- if com .is_integer (y ) and not self .data .columns .holds_integer ():
1513
- y = self .data .columns [y ]
1514
-
1524
+ super (HexBinPlot , self ).__init__ (data , x , y , ** kwargs )
1515
1525
if com .is_integer (C ) and not self .data .columns .holds_integer ():
1516
1526
C = self .data .columns [C ]
1517
-
1518
- self .x = x
1519
- self .y = y
1520
1527
self .C = C
1521
1528
1522
- @property
1523
- def nseries (self ):
1524
- return 1
1525
-
1526
1529
def _make_plot (self ):
1527
- import matplotlib .pyplot as plt
1528
-
1529
1530
x , y , data , C = self .x , self .y , self .data , self .C
1530
1531
ax = self .axes [0 ]
1531
1532
# pandas uses colormap, matplotlib uses cmap.
1532
1533
cmap = self .colormap or 'BuGn'
1533
- cmap = plt .cm .get_cmap (cmap )
1534
+ cmap = self . plt .cm .get_cmap (cmap )
1534
1535
cb = self .kwds .pop ('colorbar' , True )
1535
1536
1536
1537
if C is None :
@@ -1547,12 +1548,6 @@ def _make_plot(self):
1547
1548
def _make_legend (self ):
1548
1549
pass
1549
1550
1550
- def _post_plot_logic (self ):
1551
- ax = self .axes [0 ]
1552
- x , y = self .x , self .y
1553
- ax .set_ylabel (com .pprint_thing (y ))
1554
- ax .set_xlabel (com .pprint_thing (x ))
1555
-
1556
1551
1557
1552
class LinePlot (MPLPlot ):
1558
1553
_kind = 'line'
@@ -1685,26 +1680,23 @@ def _update_stacker(cls, ax, stacking_id, values):
1685
1680
elif (values <= 0 ).all ():
1686
1681
ax ._stacker_neg_prior [stacking_id ] += values
1687
1682
1688
- def _post_plot_logic (self ):
1689
- df = self .data
1690
-
1683
+ def _post_plot_logic (self , ax , data ):
1691
1684
condition = (not self ._use_dynamic_x ()
1692
- and df .index .is_all_dates
1685
+ and data .index .is_all_dates
1693
1686
and not self .subplots
1694
1687
or (self .subplots and self .sharex ))
1695
1688
1696
1689
index_name = self ._get_index_name ()
1697
1690
1698
- for ax in self .axes :
1699
- if condition :
1700
- # irregular TS rotated 30 deg. by default
1701
- # probably a better place to check / set this.
1702
- if not self ._rot_set :
1703
- self .rot = 30
1704
- format_date_labels (ax , rot = self .rot )
1691
+ if condition :
1692
+ # irregular TS rotated 30 deg. by default
1693
+ # probably a better place to check / set this.
1694
+ if not self ._rot_set :
1695
+ self .rot = 30
1696
+ format_date_labels (ax , rot = self .rot )
1705
1697
1706
- if index_name is not None and self .use_index :
1707
- ax .set_xlabel (index_name )
1698
+ if index_name is not None and self .use_index :
1699
+ ax .set_xlabel (index_name )
1708
1700
1709
1701
1710
1702
class AreaPlot (LinePlot ):
@@ -1758,16 +1750,14 @@ def _add_legend_handle(self, handle, label, index=None):
1758
1750
handle = Rectangle ((0 , 0 ), 1 , 1 , fc = handle .get_color (), alpha = alpha )
1759
1751
LinePlot ._add_legend_handle (self , handle , label , index = index )
1760
1752
1761
- def _post_plot_logic (self ):
1762
- LinePlot ._post_plot_logic (self )
1753
+ def _post_plot_logic (self , ax , data ):
1754
+ LinePlot ._post_plot_logic (self , ax , data )
1763
1755
1764
1756
if self .ylim is None :
1765
- if (self .data >= 0 ).all ().all ():
1766
- for ax in self .axes :
1767
- ax .set_ylim (0 , None )
1768
- elif (self .data <= 0 ).all ().all ():
1769
- for ax in self .axes :
1770
- ax .set_ylim (None , 0 )
1757
+ if (data >= 0 ).all ().all ():
1758
+ ax .set_ylim (0 , None )
1759
+ elif (data <= 0 ).all ().all ():
1760
+ ax .set_ylim (None , 0 )
1771
1761
1772
1762
1773
1763
class BarPlot (MPLPlot ):
@@ -1865,19 +1855,17 @@ def _make_plot(self):
1865
1855
start = start , label = label , log = self .log , ** kwds )
1866
1856
self ._add_legend_handle (rect , label , index = i )
1867
1857
1868
- def _post_plot_logic (self ):
1869
- for ax in self .axes :
1870
- if self .use_index :
1871
- str_index = [com .pprint_thing (key ) for key in self .data .index ]
1872
- else :
1873
- str_index = [com .pprint_thing (key ) for key in
1874
- range (self .data .shape [0 ])]
1875
- name = self ._get_index_name ()
1858
+ def _post_plot_logic (self , ax , data ):
1859
+ if self .use_index :
1860
+ str_index = [com .pprint_thing (key ) for key in data .index ]
1861
+ else :
1862
+ str_index = [com .pprint_thing (key ) for key in range (data .shape [0 ])]
1863
+ name = self ._get_index_name ()
1876
1864
1877
- s_edge = self .ax_pos [0 ] - 0.25 + self .lim_offset
1878
- e_edge = self .ax_pos [- 1 ] + 0.25 + self .bar_width + self .lim_offset
1865
+ s_edge = self .ax_pos [0 ] - 0.25 + self .lim_offset
1866
+ e_edge = self .ax_pos [- 1 ] + 0.25 + self .bar_width + self .lim_offset
1879
1867
1880
- self ._decorate_ticks (ax , name , str_index , s_edge , e_edge )
1868
+ self ._decorate_ticks (ax , name , str_index , s_edge , e_edge )
1881
1869
1882
1870
def _decorate_ticks (self , ax , name , ticklabels , start_edge , end_edge ):
1883
1871
ax .set_xlim ((start_edge , end_edge ))
@@ -1975,13 +1963,11 @@ def _make_plot_keywords(self, kwds, y):
1975
1963
kwds ['bins' ] = self .bins
1976
1964
return kwds
1977
1965
1978
- def _post_plot_logic (self ):
1966
+ def _post_plot_logic (self , ax , data ):
1979
1967
if self .orientation == 'horizontal' :
1980
- for ax in self .axes :
1981
- ax .set_xlabel ('Frequency' )
1968
+ ax .set_xlabel ('Frequency' )
1982
1969
else :
1983
- for ax in self .axes :
1984
- ax .set_ylabel ('Frequency' )
1970
+ ax .set_ylabel ('Frequency' )
1985
1971
1986
1972
@property
1987
1973
def orientation (self ):
@@ -2038,9 +2024,8 @@ def _make_plot_keywords(self, kwds, y):
2038
2024
kwds ['ind' ] = self ._get_ind (y )
2039
2025
return kwds
2040
2026
2041
- def _post_plot_logic (self ):
2042
- for ax in self .axes :
2043
- ax .set_ylabel ('Density' )
2027
+ def _post_plot_logic (self , ax , data ):
2028
+ ax .set_ylabel ('Density' )
2044
2029
2045
2030
2046
2031
class PiePlot (MPLPlot ):
@@ -2242,7 +2227,7 @@ def _set_ticklabels(self, ax, labels):
2242
2227
def _make_legend (self ):
2243
2228
pass
2244
2229
2245
- def _post_plot_logic (self ):
2230
+ def _post_plot_logic (self , ax , data ):
2246
2231
pass
2247
2232
2248
2233
@property
0 commit comments