Skip to content

Commit c6d6540

Browse files
author
Jon M. Mease
committed
Add traces with subplots cleanup
- Add optional row/col params to add_traces - Add singular add_trace method with optional row/col params - Deprecate append_trace and remap to add_trace - Add row/col paras to the add_* figure methods
1 parent 2455dc1 commit c6d6540

File tree

3 files changed

+198
-20
lines changed

3 files changed

+198
-20
lines changed

codegen/datatypes.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def reindent_validator_description(validator, extra_indent):
257257
validator.description().strip().split('\n'))
258258

259259

260-
def add_constructor_params(buffer, subtype_nodes):
260+
def add_constructor_params(buffer, subtype_nodes, extras=()):
261261
"""
262262
Write datatype constructor params to a buffer
263263
@@ -267,6 +267,8 @@ def add_constructor_params(buffer, subtype_nodes):
267267
Buffer to write to
268268
subtype_nodes : list of PlotlyNode
269269
List of datatype nodes to be written as constructor params
270+
extras : list[str]
271+
List of extra parameters to include at the end of the params
270272
Returns
271273
-------
272274
None
@@ -275,13 +277,17 @@ def add_constructor_params(buffer, subtype_nodes):
275277
buffer.write(f""",
276278
{subtype_node.name_property}=None""")
277279

280+
for extra in extras:
281+
buffer.write(f""",
282+
{extra}=None""")
283+
278284
buffer.write(""",
279285
**kwargs""")
280286
buffer.write(f"""
281287
):""")
282288

283289

284-
def add_docstring(buffer, node, header):
290+
def add_docstring(buffer, node, header, extras=()):
285291
"""
286292
Write docstring for a compound datatype node
287293
@@ -328,6 +334,17 @@ def add_docstring(buffer, node, header):
328334
buffer.write(node.get_constructor_params_docstring(
329335
indent=8))
330336

337+
# Write any extras
338+
for p, v in extras:
339+
v_wrapped = '\n'.join(textwrap.wrap(
340+
v,
341+
width=79-12,
342+
initial_indent=' ' * 12,
343+
subsequent_indent=' ' * 12))
344+
buffer.write(f"""
345+
{p}
346+
{v_wrapped}""")
347+
331348
# Write return block and close docstring
332349
# --------------------------------------
333350
buffer.write(f"""

codegen/figure.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,22 @@ def __init__(self, data=None, layout=None, frames=None):
9191
def add_{trace_node.plotly_name}(self""")
9292

9393
# #### Function params####
94-
add_constructor_params(buffer, trace_node.child_datatypes)
94+
add_constructor_params(buffer, trace_node.child_datatypes,
95+
['row', 'col'])
9596

9697
# #### Docstring ####
9798
header = f"Add a new {trace_node.name_datatype_class} trace"
98-
add_docstring(buffer, trace_node, header)
99+
100+
extras = (('row : int or None (default)',
101+
'Subplot row index (starting from 1) for the trace to be '
102+
'added. Only valid if figure was created using '
103+
'`plotly.tools.make_subplots`'),
104+
('col : int or None (default)',
105+
'Subplot col index (starting from 1) for the trace to be '
106+
'added. Only valid if figure was created using '
107+
'`plotly.tools.make_subplots`'))
108+
109+
add_docstring(buffer, trace_node, header, extras=extras)
99110

100111
# #### Function body ####
101112
buffer.write(f"""
@@ -111,7 +122,7 @@ def add_{trace_node.plotly_name}(self""")
111122
**kwargs)""")
112123

113124
buffer.write(f"""
114-
return self.add_traces(new_trace)[0]""")
125+
return self.add_trace(new_trace, row=row, col=col)""")
115126

116127
# Return source string
117128
# --------------------

plotly/basedatatypes.py

Lines changed: 165 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections
22
import re
33
import typing as typ
4+
import warnings
45
from contextlib import contextmanager
56
from copy import deepcopy
67
from typing import Dict, Tuple, Union, Callable, List
@@ -791,14 +792,114 @@ def _set_in(d, key_path_str, v):
791792

792793
# Add traces
793794
# ----------
794-
def add_traces(self, data):
795+
@staticmethod
796+
def _raise_invalid_rows_cols(name, n, invalid):
797+
rows_err_msg = """
798+
If specified, the {name} parameter must be a list or tuple of integers
799+
of length {n} (The number of traces being added)
800+
801+
Received: {invalid}
802+
""".format(name=name, n=n, invalid=invalid)
803+
804+
raise ValueError(rows_err_msg)
805+
806+
@staticmethod
807+
def _validate_rows_cols(name, n, vals):
808+
if vals is None:
809+
pass
810+
elif isinstance(vals, (list, tuple)):
811+
if len(vals) != n:
812+
BaseFigure._raise_invalid_rows_cols(
813+
name=name, n=n, invalid=vals)
814+
815+
if [r for r in vals if not isinstance(r, int)]:
816+
BaseFigure._raise_invalid_rows_cols(
817+
name=name, n=n, invalid=vals)
818+
else:
819+
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)
820+
821+
def add_trace(self, trace, row=None, col=None):
795822
"""
796-
Add one or more traces to the figure
823+
Add a trace to the figure
797824
798825
Parameters
799826
----------
800-
data : BaseTraceType or dict or list[BaseTraceType or dict]
801-
A trace specification or list of trace specifications to be added.
827+
trace : BaseTraceType or dict
828+
Either:
829+
- An instances of a trace classe from the plotly.graph_objs
830+
package (e.g plotly.graph_objs.Scatter, plotly.graph_objs.Bar)
831+
- or a dicts where:
832+
833+
- The 'type' property specifies the trace type (e.g.
834+
'scatter', 'bar', 'area', etc.). If the dict has no 'type'
835+
property then 'scatter' is assumed.
836+
- All remaining properties are passed to the constructor
837+
of the specified trace type.
838+
839+
row : int or None (default)
840+
Subplot row index (starting from 1) for the trace to be added.
841+
Only valid if figure was created using
842+
`plotly.tools.make_subplots`
843+
col : int or None (default)
844+
Subplot col index (starting from 1) for the trace to be added.
845+
Only valid if figure was created using
846+
`plotly.tools.make_subplots`
847+
848+
Returns
849+
-------
850+
BaseTraceType
851+
The newly added trace
852+
853+
Examples
854+
--------
855+
>>> from plotly import tools
856+
>>> import plotly.graph_objs as go
857+
858+
Add two Scatter traces to a figure
859+
>>> fig = go.Figure()
860+
>>> fig.add_trace(go.Scatter(x=[1,2,3], y=[2,1,2]))
861+
>>> fig.add_trace(go.Scatter(x=[1,2,3], y=[2,1,2]))
862+
863+
864+
Add two Scatter traces to vertically stacked subplots
865+
>>> fig = tools.make_subplots(rows=2)
866+
This is the format of your plot grid:
867+
[ (1,1) x1,y1 ]
868+
[ (2,1) x2,y2 ]
869+
870+
>>> fig.add_trace(go.Scatter(x=[1,2,3], y=[2,1,2]), row=1, col=1)
871+
>>> fig.add_trace(go.Scatter(x=[1,2,3], y=[2,1,2]), row=2, col=1)
872+
"""
873+
# Validate row/col
874+
if row is not None and not isinstance(row, int):
875+
pass
876+
877+
if col is not None and not isinstance(col, int):
878+
pass
879+
880+
# Make sure we have both row and col or neither
881+
if row is not None and col is None:
882+
raise ValueError(
883+
'Received row parameter but not col.\n'
884+
'row and col must be specified together')
885+
elif col is not None and row is None:
886+
raise ValueError(
887+
'Received col parameter but not row.\n'
888+
'row and col must be specified together')
889+
890+
return self.add_traces(data=[trace],
891+
rows=[row] if row is not None else None,
892+
cols=[col] if col is not None else None
893+
)[0]
894+
895+
def add_traces(self, data, rows=None, cols=None):
896+
"""
897+
Add traces to the figure
898+
899+
Parameters
900+
----------
901+
data : list[BaseTraceType or dict]
902+
A list of trace specifications to be added.
802903
Trace specifications may be either:
803904
804905
- Instances of trace classes from the plotly.graph_objs
@@ -810,23 +911,70 @@ def add_traces(self, data):
810911
property then 'scatter' is assumed.
811912
- All remaining properties are passed to the constructor
812913
of the specified trace type.
914+
915+
rows : None or list[int] (default None)
916+
List of subplot row indexes (starting from 1) for the traces to be
917+
added. Only valid if figure was created using
918+
`plotly.tools.make_subplots`
919+
cols : None or list[int] (default None)
920+
List of subplot column indexes (starting from 1) for the traces
921+
to be added. Only valid if figure was created using
922+
`plotly.tools.make_subplots`
923+
813924
Returns
814925
-------
815926
tuple[BaseTraceType]
816-
Tuple of the newly added trace(s)
927+
Tuple of the newly added traces
928+
929+
Examples
930+
--------
931+
>>> from plotly import tools
932+
>>> import plotly.graph_objs as go
933+
934+
Add two Scatter traces to a figure
935+
>>> fig = go.Figure()
936+
>>> fig.add_traces([go.Scatter(x=[1,2,3], y=[2,1,2]),
937+
... go.Scatter(x=[1,2,3], y=[2,1,2])])
938+
939+
Add two Scatter traces to vertically stacked subplots
940+
>>> fig = tools.make_subplots(rows=2)
941+
This is the format of your plot grid:
942+
[ (1,1) x1,y1 ]
943+
[ (2,1) x2,y2 ]
944+
945+
>>> fig.add_traces([go.Scatter(x=[1,2,3], y=[2,1,2]),
946+
... go.Scatter(x=[1,2,3], y=[2,1,2])],
947+
... rows=[1, 2], cols=[1, 1])
817948
"""
818949

819950
if self._in_batch_mode:
820951
self._batch_layout_edits.clear()
821952
self._batch_trace_edits.clear()
822953
raise ValueError('Traces may not be added in a batch context')
823954

824-
if not isinstance(data, (list, tuple)):
825-
data = [data]
826-
827-
# Validate
955+
# Validate traces
828956
data = self._data_validator.validate_coerce(data)
829957

958+
# Validate rows / cols
959+
n = len(data)
960+
BaseFigure._validate_rows_cols('rows', n, rows)
961+
BaseFigure._validate_rows_cols('cols', n, cols)
962+
963+
# Make sure we have both rows and cols or neither
964+
if rows is not None and cols is None:
965+
raise ValueError(
966+
'Received rows parameter but not cols.\n'
967+
'rows and cols must be specified together')
968+
elif cols is not None and rows is None:
969+
raise ValueError(
970+
'Received cols parameter but not rows.\n'
971+
'rows and cols must be specified together')
972+
973+
# Apply rows / cols
974+
if rows is not None:
975+
for trace, row, col in zip(data, rows, cols):
976+
self._set_trace_grid_position(trace, row, col)
977+
830978
# Make deep copy of trace data (Optimize later if needed)
831979
new_traces_data = [deepcopy(trace._props) for trace in data]
832980

@@ -877,10 +1025,6 @@ def append_trace(self, trace, row, col):
8771025
col: int
8781026
Subplot column index (see Figure.print_grid)
8791027
880-
:param (dict) trace: The data trace to be bound.
881-
:param (int) row: Subplot row index (see Figure.print_grid).
882-
:param (int) col: Subplot column index (see Figure.print_grid).
883-
8841028
Examples
8851029
--------
8861030
>>> from plotly import tools
@@ -894,6 +1038,14 @@ def append_trace(self, trace, row, col):
8941038
>>> fig.append_trace(go.Scatter(x=[1,2,3], y=[2,1,2]), row=1, col=1)
8951039
>>> fig.append_trace(go.Scatter(x=[1,2,3], y=[2,1,2]), row=2, col=1)
8961040
"""
1041+
warnings.warn("""\
1042+
The append_trace method is deprecated and will be removed in a future version.
1043+
Please use the add_trace method with the row and col parameters.
1044+
""", DeprecationWarning)
1045+
1046+
self.add_trace(trace=trace, row=row, col=col)
1047+
1048+
def _set_trace_grid_position(self, trace, row, col):
8971049
try:
8981050
grid_ref = self._grid_ref
8991051
except AttributeError:
@@ -931,8 +1083,6 @@ def append_trace(self, trace, row, col):
9311083
trace['xaxis'] = ref[0]
9321084
trace['yaxis'] = ref[1]
9331085

934-
self.add_traces([trace])
935-
9361086
# Child property operations
9371087
# -------------------------
9381088
def _get_child_props(self, child):

0 commit comments

Comments
 (0)