From 367973194a0324cf61159eddefc76f47aee00139 Mon Sep 17 00:00:00 2001 From: Sam Cunliffe Date: Thu, 8 Jun 2023 09:53:09 +0100 Subject: [PATCH 1/4] Test first. A TDD test to check we can set the theme from a user-defined stylesheet. For this, just use Solarized_Light2. --- src/napari_matplotlib/tests/test_theme.py | 38 ++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/napari_matplotlib/tests/test_theme.py b/src/napari_matplotlib/tests/test_theme.py index dfeb5b5f..1e4a2073 100644 --- a/src/napari_matplotlib/tests/test_theme.py +++ b/src/napari_matplotlib/tests/test_theme.py @@ -1,8 +1,13 @@ +import shutil +from pathlib import Path + +import matplotlib import napari import numpy as np import pytest +from matplotlib.colors import to_rgba -from napari_matplotlib import ScatterWidget +from napari_matplotlib import HistogramWidget, ScatterWidget from napari_matplotlib.base import NapariMPLWidget @@ -88,3 +93,34 @@ def test_titles_respect_theme( assert ax.xaxis.label.get_color() == expected_text_colour assert ax.yaxis.label.get_color() == expected_text_colour + + +def find_mpl_stylesheet(name: str) -> Path: + """Find the built-in matplotlib stylesheet.""" + return Path(matplotlib.__path__[0]) / f"mpl-data/stylelib/{name}.mplstyle" + + +def test_stylesheet_in_cwd(tmpdir, make_napari_viewer, image_data): + """ + Test that a stylesheet in the current directory is given precidence. + + Do this by copying over a stylesheet from matplotlib's built in styles, + naming it correctly, and checking the colours are as expected. + """ + with tmpdir.as_cwd(): + # Copy Solarize_Light2 to current dir as if it was a user-overriden stylesheet. + shutil.copy(find_mpl_stylesheet("Solarize_Light2"), "./user.mplstyle") + viewer = make_napari_viewer() + viewer.add_image(image_data[0], **image_data[1]) + widget = HistogramWidget(viewer) + ax = widget.figure.gca() + + # The axes should have a light brownish grey background: + assert ax.get_facecolor() == to_rgba("#eee8d5") + assert ax.patch.get_facecolor() == to_rgba("#eee8d5") + + # The figure background and axis gridlines are light yellow: + assert widget.figure.patch.get_facecolor() == to_rgba("#fdf6e3") + for gridline in ax.get_xgridlines() + ax.get_ygridlines(): + assert gridline.get_visible() is True + assert gridline.get_color() == "#fdf6e3" From be3ed986bcfd458f6f06919f619aec7363ab5a67 Mon Sep 17 00:00:00 2001 From: Sam Cunliffe Date: Thu, 8 Jun 2023 09:57:25 +0100 Subject: [PATCH 2/4] Users can override the napari-theme-based style ... with their own custom stylesheet. At the moment this must be called 'user.mplstyle' in the cwd. --- src/napari_matplotlib/base.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index fda2d2d5..9e0d50e8 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import List, Optional, Tuple +import matplotlib.style import napari from matplotlib.axes import Axes from matplotlib.backends.backend_qtagg import ( @@ -41,9 +42,11 @@ def __init__( super().__init__(parent=parent) self.viewer = napari_viewer + has_mpl_stylesheet = self._apply_user_stylesheet_if_present() self.canvas = FigureCanvas() - self.canvas.figure.patch.set_facecolor("none") + if not has_mpl_stylesheet: + self.canvas.figure.patch.set_facecolor("none") self.canvas.figure.set_layout_engine("constrained") self.toolbar = NapariNavigationToolbar( self.canvas, parent=self @@ -70,10 +73,16 @@ def add_single_axes(self) -> None: The Axes is saved on the ``.axes`` attribute for later access. """ self.axes = self.figure.subplots() - self.apply_napari_colorscheme(self.axes) + self.apply_style(self.axes) + + def apply_style(self, ax: Axes) -> None: + """ + Use the user-supplied stylesheet if present, otherwise apply the + napari-compatible colorscheme (theme-dependent) to an Axes. + """ + if self._apply_user_stylesheet_if_present(): + return - def apply_napari_colorscheme(self, ax: Axes) -> None: - """Apply napari-compatible colorscheme to an Axes.""" # get the foreground colours from current theme theme = napari.utils.theme.get_theme(self.viewer.theme, as_dict=False) fg_colour = theme.foreground.as_hex() # fg is a muted contrast to bg @@ -93,6 +102,20 @@ def apply_napari_colorscheme(self, ax: Axes) -> None: ax.tick_params(axis="x", colors=text_colour) ax.tick_params(axis="y", colors=text_colour) + def _apply_user_stylesheet_if_present(self) -> bool: + """ + Apply the user-supplied stylesheet if present. + + Returns + ------- + True if the stylesheet was present and applied. + False otherwise. + """ + if (Path.cwd() / "user.mplstyle").exists(): + matplotlib.style.use("./user.mplstyle") + return True + return False + def _on_theme_change(self) -> None: """Update MPL toolbar and axis styling when `napari.Viewer.theme` is changed. @@ -101,7 +124,7 @@ def _on_theme_change(self) -> None: """ self._replace_toolbar_icons() if self.figure.gca(): - self.apply_napari_colorscheme(self.figure.gca()) + self.apply_style(self.figure.gca()) def _theme_has_light_bg(self) -> bool: """ @@ -245,7 +268,7 @@ def _draw(self) -> None: isinstance(layer, self.input_layer_types) for layer in self.layers ): self.draw() - self.apply_napari_colorscheme(self.figure.gca()) + self.apply_style(self.figure.gca()) self.canvas.draw() def clear(self) -> None: From dd36b143b89b960d8cc58193b29f942ad2321514 Mon Sep 17 00:00:00 2001 From: Sam Cunliffe Date: Mon, 12 Jun 2023 14:19:11 +0100 Subject: [PATCH 3/4] Test to guard against re-introducing #64. --- src/napari_matplotlib/tests/test_theme.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/napari_matplotlib/tests/test_theme.py b/src/napari_matplotlib/tests/test_theme.py index 1e4a2073..68b46430 100644 --- a/src/napari_matplotlib/tests/test_theme.py +++ b/src/napari_matplotlib/tests/test_theme.py @@ -1,4 +1,5 @@ import shutil +from copy import deepcopy from pathlib import Path import matplotlib @@ -124,3 +125,22 @@ def test_stylesheet_in_cwd(tmpdir, make_napari_viewer, image_data): for gridline in ax.get_xgridlines() + ax.get_ygridlines(): assert gridline.get_visible() is True assert gridline.get_color() == "#fdf6e3" + + +@pytest.mark.mpl_image_compare +def test_theme_doesnt_leak(make_napari_viewer): + """Ensure that napari-matplotlib doesn't pollute the globally set style. + + A MWE to guard aganst issue matplotlib/#64. Should always reproduce a plot + with the default matplotlib style. + """ + import matplotlib.pyplot as plt + + # should not affect global style + viewer = make_napari_viewer() + HistogramWidget(viewer) + + np.random.seed(12345) + image = np.random.random((3, 3)) + plot = plt.imshow(image) + return deepcopy(plot) From 2de132e566850770f1d2e2721aad944fc2675d30 Mon Sep 17 00:00:00 2001 From: Sam Cunliffe Date: Thu, 15 Jun 2023 10:59:13 +0100 Subject: [PATCH 4/4] Tweak the side-effects test to actually test this cwd functionality. --- src/napari_matplotlib/tests/test_theme.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/napari_matplotlib/tests/test_theme.py b/src/napari_matplotlib/tests/test_theme.py index 145fd53b..988b6c35 100644 --- a/src/napari_matplotlib/tests/test_theme.py +++ b/src/napari_matplotlib/tests/test_theme.py @@ -96,7 +96,7 @@ def test_titles_respect_theme( @pytest.mark.mpl_image_compare -def test_no_theme_side_effects(make_napari_viewer): +def test_no_theme_side_effects(tmpdir, make_napari_viewer): """Ensure that napari-matplotlib doesn't pollute the globally set style. A MWE to guard aganst issue matplotlib/#64. Should always reproduce a plot @@ -107,9 +107,11 @@ def test_no_theme_side_effects(make_napari_viewer): np.random.seed(12345) # should not affect global matplotlib plot style - viewer = make_napari_viewer() - viewer.theme = "dark" - NapariMPLWidget(viewer) + with tmpdir.as_cwd(): + shutil.copy(find_mpl_stylesheet("Solarize_Light2"), "./user.mplstyle") + viewer = make_napari_viewer() + viewer.theme = "dark" + NapariMPLWidget(viewer) # some plotting unrelated to napari-matplotlib normal_dist = np.random.normal(size=1000)