Skip to content

Commit b4ce396

Browse files
committed
Merge pull request #10717 from sinhrks/plot_cln
CLN: plotting cleanups for groupby plotting
2 parents 9ef7ebb + d558f16 commit b4ce396

File tree

1 file changed

+94
-109
lines changed

1 file changed

+94
-109
lines changed

pandas/tools/plotting.py

Lines changed: 94 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,12 @@ class MPLPlot(object):
774774
data :
775775
776776
"""
777-
_kind = 'base'
777+
778+
@property
779+
def _kind(self):
780+
"""Specify kind str. Must be overridden in child class"""
781+
raise NotImplementedError
782+
778783
_layout_type = 'vertical'
779784
_default_rot = 0
780785
orientation = None
@@ -938,7 +943,10 @@ def generate(self):
938943
self._make_plot()
939944
self._add_table()
940945
self._make_legend()
941-
self._post_plot_logic()
946+
947+
for ax in self.axes:
948+
self._post_plot_logic_common(ax, self.data)
949+
self._post_plot_logic(ax, self.data)
942950
self._adorn_subplots()
943951

944952
def _args_adjust(self):
@@ -1055,12 +1063,34 @@ def _add_table(self):
10551063
ax = self._get_ax(0)
10561064
table(ax, data)
10571065

1058-
def _post_plot_logic(self):
1066+
def _post_plot_logic_common(self, ax, data):
1067+
"""Common post process for each axes"""
1068+
labels = [com.pprint_thing(key) for key in data.index]
1069+
labels = dict(zip(range(len(data.index)), labels))
1070+
1071+
if self.orientation == 'vertical' or self.orientation is None:
1072+
if self._need_to_set_index:
1073+
xticklabels = [labels.get(x, '') for x in ax.get_xticks()]
1074+
ax.set_xticklabels(xticklabels)
1075+
self._apply_axis_properties(ax.xaxis, rot=self.rot,
1076+
fontsize=self.fontsize)
1077+
self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
1078+
elif self.orientation == 'horizontal':
1079+
if self._need_to_set_index:
1080+
yticklabels = [labels.get(y, '') for y in ax.get_yticks()]
1081+
ax.set_yticklabels(yticklabels)
1082+
self._apply_axis_properties(ax.yaxis, rot=self.rot,
1083+
fontsize=self.fontsize)
1084+
self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)
1085+
else: # pragma no cover
1086+
raise ValueError
1087+
1088+
def _post_plot_logic(self, ax, data):
1089+
"""Post process for each axes. Overridden in child classes"""
10591090
pass
10601091

10611092
def _adorn_subplots(self):
1062-
to_adorn = self.axes
1063-
1093+
"""Common post process unrelated to data"""
10641094
if len(self.axes) > 0:
10651095
all_axes = self._get_axes()
10661096
nrows, ncols = self._get_axes_layout()
@@ -1069,7 +1099,7 @@ def _adorn_subplots(self):
10691099
ncols=ncols, sharex=self.sharex,
10701100
sharey=self.sharey)
10711101

1072-
for ax in to_adorn:
1102+
for ax in self.axes:
10731103
if self.yticks is not None:
10741104
ax.set_yticks(self.yticks)
10751105

@@ -1090,25 +1120,6 @@ def _adorn_subplots(self):
10901120
else:
10911121
self.axes[0].set_title(self.title)
10921122

1093-
labels = [com.pprint_thing(key) for key in self.data.index]
1094-
labels = dict(zip(range(len(self.data.index)), labels))
1095-
1096-
for ax in self.axes:
1097-
if self.orientation == 'vertical' or self.orientation is None:
1098-
if self._need_to_set_index:
1099-
xticklabels = [labels.get(x, '') for x in ax.get_xticks()]
1100-
ax.set_xticklabels(xticklabels)
1101-
self._apply_axis_properties(ax.xaxis, rot=self.rot,
1102-
fontsize=self.fontsize)
1103-
self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
1104-
elif self.orientation == 'horizontal':
1105-
if self._need_to_set_index:
1106-
yticklabels = [labels.get(y, '') for y in ax.get_yticks()]
1107-
ax.set_yticklabels(yticklabels)
1108-
self._apply_axis_properties(ax.yaxis, rot=self.rot,
1109-
fontsize=self.fontsize)
1110-
self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)
1111-
11121123
def _apply_axis_properties(self, axis, rot=None, fontsize=None):
11131124
labels = axis.get_majorticklabels() + axis.get_minorticklabels()
11141125
for label in labels:
@@ -1419,34 +1430,48 @@ def _get_axes_layout(self):
14191430
y_set.add(points[0][1])
14201431
return (len(y_set), len(x_set))
14211432

1422-
class ScatterPlot(MPLPlot):
1423-
_kind = 'scatter'
1433+
1434+
class PlanePlot(MPLPlot):
1435+
"""
1436+
Abstract class for plotting on plane, currently scatter and hexbin.
1437+
"""
1438+
14241439
_layout_type = 'single'
14251440

1426-
def __init__(self, data, x, y, c=None, **kwargs):
1441+
def __init__(self, data, x, y, **kwargs):
14271442
MPLPlot.__init__(self, data, **kwargs)
14281443
if x is None or y is None:
1429-
raise ValueError( 'scatter requires and x and y column')
1444+
raise ValueError(self._kind + ' requires and x and y column')
14301445
if com.is_integer(x) and not self.data.columns.holds_integer():
14311446
x = self.data.columns[x]
14321447
if com.is_integer(y) and not self.data.columns.holds_integer():
14331448
y = self.data.columns[y]
1434-
if com.is_integer(c) and not self.data.columns.holds_integer():
1435-
c = self.data.columns[c]
14361449
self.x = x
14371450
self.y = y
1438-
self.c = c
14391451

14401452
@property
14411453
def nseries(self):
14421454
return 1
14431455

1456+
def _post_plot_logic(self, ax, data):
1457+
x, y = self.x, self.y
1458+
ax.set_ylabel(com.pprint_thing(y))
1459+
ax.set_xlabel(com.pprint_thing(x))
1460+
1461+
1462+
class ScatterPlot(PlanePlot):
1463+
_kind = 'scatter'
1464+
1465+
def __init__(self, data, x, y, c=None, **kwargs):
1466+
super(ScatterPlot, self).__init__(data, x, y, **kwargs)
1467+
if com.is_integer(c) and not self.data.columns.holds_integer():
1468+
c = self.data.columns[c]
1469+
self.c = c
1470+
14441471
def _make_plot(self):
14451472
import matplotlib as mpl
14461473
mpl_ge_1_3_1 = str(mpl.__version__) >= LooseVersion('1.3.1')
14471474

1448-
import matplotlib.pyplot as plt
1449-
14501475
x, y, c, data = self.x, self.y, self.c, self.data
14511476
ax = self.axes[0]
14521477

@@ -1457,7 +1482,7 @@ def _make_plot(self):
14571482

14581483
# pandas uses colormap, matplotlib uses cmap.
14591484
cmap = self.colormap or 'Greys'
1460-
cmap = plt.cm.get_cmap(cmap)
1485+
cmap = self.plt.cm.get_cmap(cmap)
14611486

14621487
if c is None:
14631488
c_values = self.plt.rcParams['patch.facecolor']
@@ -1491,46 +1516,22 @@ def _make_plot(self):
14911516
err_kwds['ecolor'] = scatter.get_facecolor()[0]
14921517
ax.errorbar(data[x].values, data[y].values, linestyle='none', **err_kwds)
14931518

1494-
def _post_plot_logic(self):
1495-
ax = self.axes[0]
1496-
x, y = self.x, self.y
1497-
ax.set_ylabel(com.pprint_thing(y))
1498-
ax.set_xlabel(com.pprint_thing(x))
1499-
15001519

1501-
class HexBinPlot(MPLPlot):
1520+
class HexBinPlot(PlanePlot):
15021521
_kind = 'hexbin'
1503-
_layout_type = 'single'
15041522

15051523
def __init__(self, data, x, y, C=None, **kwargs):
1506-
MPLPlot.__init__(self, data, **kwargs)
1507-
1508-
if x is None or y is None:
1509-
raise ValueError('hexbin requires and x and y column')
1510-
if com.is_integer(x) and not self.data.columns.holds_integer():
1511-
x = self.data.columns[x]
1512-
if com.is_integer(y) and not self.data.columns.holds_integer():
1513-
y = self.data.columns[y]
1514-
1524+
super(HexBinPlot, self).__init__(data, x, y, **kwargs)
15151525
if com.is_integer(C) and not self.data.columns.holds_integer():
15161526
C = self.data.columns[C]
1517-
1518-
self.x = x
1519-
self.y = y
15201527
self.C = C
15211528

1522-
@property
1523-
def nseries(self):
1524-
return 1
1525-
15261529
def _make_plot(self):
1527-
import matplotlib.pyplot as plt
1528-
15291530
x, y, data, C = self.x, self.y, self.data, self.C
15301531
ax = self.axes[0]
15311532
# pandas uses colormap, matplotlib uses cmap.
15321533
cmap = self.colormap or 'BuGn'
1533-
cmap = plt.cm.get_cmap(cmap)
1534+
cmap = self.plt.cm.get_cmap(cmap)
15341535
cb = self.kwds.pop('colorbar', True)
15351536

15361537
if C is None:
@@ -1547,12 +1548,6 @@ def _make_plot(self):
15471548
def _make_legend(self):
15481549
pass
15491550

1550-
def _post_plot_logic(self):
1551-
ax = self.axes[0]
1552-
x, y = self.x, self.y
1553-
ax.set_ylabel(com.pprint_thing(y))
1554-
ax.set_xlabel(com.pprint_thing(x))
1555-
15561551

15571552
class LinePlot(MPLPlot):
15581553
_kind = 'line'
@@ -1685,26 +1680,23 @@ def _update_stacker(cls, ax, stacking_id, values):
16851680
elif (values <= 0).all():
16861681
ax._stacker_neg_prior[stacking_id] += values
16871682

1688-
def _post_plot_logic(self):
1689-
df = self.data
1690-
1683+
def _post_plot_logic(self, ax, data):
16911684
condition = (not self._use_dynamic_x()
1692-
and df.index.is_all_dates
1685+
and data.index.is_all_dates
16931686
and not self.subplots
16941687
or (self.subplots and self.sharex))
16951688

16961689
index_name = self._get_index_name()
16971690

1698-
for ax in self.axes:
1699-
if condition:
1700-
# irregular TS rotated 30 deg. by default
1701-
# probably a better place to check / set this.
1702-
if not self._rot_set:
1703-
self.rot = 30
1704-
format_date_labels(ax, rot=self.rot)
1691+
if condition:
1692+
# irregular TS rotated 30 deg. by default
1693+
# probably a better place to check / set this.
1694+
if not self._rot_set:
1695+
self.rot = 30
1696+
format_date_labels(ax, rot=self.rot)
17051697

1706-
if index_name is not None and self.use_index:
1707-
ax.set_xlabel(index_name)
1698+
if index_name is not None and self.use_index:
1699+
ax.set_xlabel(index_name)
17081700

17091701

17101702
class AreaPlot(LinePlot):
@@ -1758,16 +1750,14 @@ def _add_legend_handle(self, handle, label, index=None):
17581750
handle = Rectangle((0, 0), 1, 1, fc=handle.get_color(), alpha=alpha)
17591751
LinePlot._add_legend_handle(self, handle, label, index=index)
17601752

1761-
def _post_plot_logic(self):
1762-
LinePlot._post_plot_logic(self)
1753+
def _post_plot_logic(self, ax, data):
1754+
LinePlot._post_plot_logic(self, ax, data)
17631755

17641756
if self.ylim is None:
1765-
if (self.data >= 0).all().all():
1766-
for ax in self.axes:
1767-
ax.set_ylim(0, None)
1768-
elif (self.data <= 0).all().all():
1769-
for ax in self.axes:
1770-
ax.set_ylim(None, 0)
1757+
if (data >= 0).all().all():
1758+
ax.set_ylim(0, None)
1759+
elif (data <= 0).all().all():
1760+
ax.set_ylim(None, 0)
17711761

17721762

17731763
class BarPlot(MPLPlot):
@@ -1865,19 +1855,17 @@ def _make_plot(self):
18651855
start=start, label=label, log=self.log, **kwds)
18661856
self._add_legend_handle(rect, label, index=i)
18671857

1868-
def _post_plot_logic(self):
1869-
for ax in self.axes:
1870-
if self.use_index:
1871-
str_index = [com.pprint_thing(key) for key in self.data.index]
1872-
else:
1873-
str_index = [com.pprint_thing(key) for key in
1874-
range(self.data.shape[0])]
1875-
name = self._get_index_name()
1858+
def _post_plot_logic(self, ax, data):
1859+
if self.use_index:
1860+
str_index = [com.pprint_thing(key) for key in data.index]
1861+
else:
1862+
str_index = [com.pprint_thing(key) for key in range(data.shape[0])]
1863+
name = self._get_index_name()
18761864

1877-
s_edge = self.ax_pos[0] - 0.25 + self.lim_offset
1878-
e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset
1865+
s_edge = self.ax_pos[0] - 0.25 + self.lim_offset
1866+
e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset
18791867

1880-
self._decorate_ticks(ax, name, str_index, s_edge, e_edge)
1868+
self._decorate_ticks(ax, name, str_index, s_edge, e_edge)
18811869

18821870
def _decorate_ticks(self, ax, name, ticklabels, start_edge, end_edge):
18831871
ax.set_xlim((start_edge, end_edge))
@@ -1975,13 +1963,11 @@ def _make_plot_keywords(self, kwds, y):
19751963
kwds['bins'] = self.bins
19761964
return kwds
19771965

1978-
def _post_plot_logic(self):
1966+
def _post_plot_logic(self, ax, data):
19791967
if self.orientation == 'horizontal':
1980-
for ax in self.axes:
1981-
ax.set_xlabel('Frequency')
1968+
ax.set_xlabel('Frequency')
19821969
else:
1983-
for ax in self.axes:
1984-
ax.set_ylabel('Frequency')
1970+
ax.set_ylabel('Frequency')
19851971

19861972
@property
19871973
def orientation(self):
@@ -2038,9 +2024,8 @@ def _make_plot_keywords(self, kwds, y):
20382024
kwds['ind'] = self._get_ind(y)
20392025
return kwds
20402026

2041-
def _post_plot_logic(self):
2042-
for ax in self.axes:
2043-
ax.set_ylabel('Density')
2027+
def _post_plot_logic(self, ax, data):
2028+
ax.set_ylabel('Density')
20442029

20452030

20462031
class PiePlot(MPLPlot):
@@ -2242,7 +2227,7 @@ def _set_ticklabels(self, ax, labels):
22422227
def _make_legend(self):
22432228
pass
22442229

2245-
def _post_plot_logic(self):
2230+
def _post_plot_logic(self, ax, data):
22462231
pass
22472232

22482233
@property

0 commit comments

Comments
 (0)