Skip to content

Commit b409770

Browse files
Merge pull request #2921 from plotly/annotated_heatmap_colors
honor zmin/zmid/zmax in annotated_heatmap font colors
2 parents 0458ac9 + f805e3c commit b409770

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

packages/python/plotly/plotly/figure_factory/_annotated_heatmap.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,23 @@ def __init__(
189189
self.reversescale = reversescale
190190
self.font_colors = font_colors
191191

192+
if np and isinstance(self.z, np.ndarray):
193+
self.zmin = np.amin(self.z)
194+
self.zmax = np.amax(self.z)
195+
else:
196+
self.zmin = min([v for row in self.z for v in row])
197+
self.zmax = max([v for row in self.z for v in row])
198+
199+
if kwargs.get("zmin", None) is not None:
200+
self.zmin = kwargs["zmin"]
201+
if kwargs.get("zmax", None) is not None:
202+
self.zmax = kwargs["zmax"]
203+
204+
self.zmid = (self.zmax + self.zmin) / 2
205+
206+
if kwargs.get("zmid", None) is not None:
207+
self.zmid = kwargs["zmid"]
208+
192209
def get_text_color(self):
193210
"""
194211
Get font color for annotations.
@@ -264,21 +281,6 @@ def get_text_color(self):
264281
max_text_color = black
265282
return min_text_color, max_text_color
266283

267-
def get_z_mid(self):
268-
"""
269-
Get the mid value of z matrix
270-
271-
:rtype (float) z_avg: average val from z matrix
272-
"""
273-
if np and isinstance(self.z, np.ndarray):
274-
z_min = np.amin(self.z)
275-
z_max = np.amax(self.z)
276-
else:
277-
z_min = min([v for row in self.z for v in row])
278-
z_max = max([v for row in self.z for v in row])
279-
z_mid = (z_max + z_min) / 2
280-
return z_mid
281-
282284
def make_annotations(self):
283285
"""
284286
Get annotations for each cell of the heatmap with graph_objs.Annotation
@@ -287,11 +289,10 @@ def make_annotations(self):
287289
the heatmap
288290
"""
289291
min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
290-
z_mid = _AnnotatedHeatmap.get_z_mid(self)
291292
annotations = []
292293
for n, row in enumerate(self.z):
293294
for m, val in enumerate(row):
294-
font_color = min_text_color if val < z_mid else max_text_color
295+
font_color = min_text_color if val < self.zmid else max_text_color
295296
annotations.append(
296297
graph_objs.layout.Annotation(
297298
text=str(self.annotation_text[n][m]),

0 commit comments

Comments
 (0)