diff --git a/baseline/test_feature_histogram2.png b/baseline/test_feature_histogram2.png new file mode 100644 index 00000000..b7bb19e0 Binary files /dev/null and b/baseline/test_feature_histogram2.png differ diff --git a/docs/changelog.rst b/docs/changelog.rst index 2304fecf..6f77e0c3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,7 +1,11 @@ Changelog ========= -1.0.3 +1.1.0 ----- +Additions +~~~~~~~~~ +- Added a widget to draw a histogram of features. + Changes ~~~~~~~ - The slice widget is now limited to slicing along the x/y dimensions. Support diff --git a/docs/user_guide.rst b/docs/user_guide.rst index 0872e540..fbd48db1 100644 --- a/docs/user_guide.rst +++ b/docs/user_guide.rst @@ -30,6 +30,7 @@ These widgets plot the data stored in the ``.features`` attribute of individual Currently available are: - 2D scatter plots of two features against each other. +- Histograms of individual features. To use these: diff --git a/src/napari_matplotlib/features.py b/src/napari_matplotlib/features.py new file mode 100644 index 00000000..34abf104 --- /dev/null +++ b/src/napari_matplotlib/features.py @@ -0,0 +1,9 @@ +from napari.layers import Labels, Points, Shapes, Tracks, Vectors + +FEATURES_LAYER_TYPES = ( + Labels, + Points, + Shapes, + Tracks, + Vectors, +) diff --git a/src/napari_matplotlib/histogram.py b/src/napari_matplotlib/histogram.py index ab098f38..66aa7acc 100644 --- a/src/napari_matplotlib/histogram.py +++ b/src/napari_matplotlib/histogram.py @@ -1,13 +1,15 @@ -from typing import Optional +from typing import Any, List, Optional, Tuple import napari import numpy as np -from qtpy.QtWidgets import QWidget +import numpy.typing as npt +from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget from .base import SingleAxesWidget +from .features import FEATURES_LAYER_TYPES from .util import Interval -__all__ = ["HistogramWidget"] +__all__ = ["HistogramWidget", "FeaturesHistogramWidget"] _COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"} @@ -61,3 +63,112 @@ def draw(self) -> None: self.axes.hist(data.ravel(), bins=bins, label=layer.name) self.axes.legend() + + +class FeaturesHistogramWidget(SingleAxesWidget): + """ + Display a histogram of selected feature attached to selected layer. + """ + + n_layers_input = Interval(1, 1) + # All layers that have a .features attributes + input_layer_types = FEATURES_LAYER_TYPES + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + parent: Optional[QWidget] = None, + ): + super().__init__(napari_viewer, parent=parent) + + self.layout().addLayout(QVBoxLayout()) + self._key_selection_widget = QComboBox() + self.layout().addWidget(QLabel("Key:")) + self.layout().addWidget(self._key_selection_widget) + + self._key_selection_widget.currentTextChanged.connect( + self._set_axis_keys + ) + + self._update_layers(None) + + @property + def x_axis_key(self) -> Optional[str]: + """Key to access x axis data from the FeaturesTable""" + return self._x_axis_key + + @x_axis_key.setter + def x_axis_key(self, key: Optional[str]) -> None: + self._x_axis_key = key + self._draw() + + def _set_axis_keys(self, x_axis_key: str) -> None: + """Set both axis keys and then redraw the plot""" + self._x_axis_key = x_axis_key + self._draw() + + def _get_valid_axis_keys(self) -> List[str]: + """ + Get the valid axis keys from the layer FeatureTable. + + Returns + ------- + axis_keys : List[str] + The valid axis keys in the FeatureTable. If the table is empty + or there isn't a table, returns an empty list. + """ + if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): + return [] + else: + return self.layers[0].features.keys() + + def _get_data(self) -> Tuple[Optional[npt.NDArray[Any]], str]: + """Get the plot data. + + Returns + ------- + data : List[np.ndarray] + List contains X and Y columns from the FeatureTable. Returns + an empty array if nothing to plot. + x_axis_name : str + The title to display on the x axis. Returns + an empty string if nothing to plot. + """ + if not hasattr(self.layers[0], "features"): + # if the selected layer doesn't have a featuretable, + # skip draw + return None, "" + + feature_table = self.layers[0].features + + if (len(feature_table) == 0) or (self.x_axis_key is None): + return None, "" + + data = feature_table[self.x_axis_key] + x_axis_name = self.x_axis_key.replace("_", " ") + + return data, x_axis_name + + def on_update_layers(self) -> None: + """ + Called when the layer selection changes by ``self.update_layers()``. + """ + # reset the axis keys + self._x_axis_key = None + + # Clear combobox + self._key_selection_widget.clear() + self._key_selection_widget.addItems(self._get_valid_axis_keys()) + + def draw(self) -> None: + """Clear the axes and histogram the currently selected layer/slice.""" + data, x_axis_name = self._get_data() + + if data is None: + return + + self.axes.hist(data, bins=50, edgecolor="white", linewidth=0.3) + + # set ax labels + self.axes.set_xlabel(x_axis_name) + self.axes.set_ylabel("Counts [#]") diff --git a/src/napari_matplotlib/napari.yaml b/src/napari_matplotlib/napari.yaml index b736592b..71af0ca6 100644 --- a/src/napari_matplotlib/napari.yaml +++ b/src/napari_matplotlib/napari.yaml @@ -14,6 +14,10 @@ contributions: python_name: napari_matplotlib:FeaturesScatterWidget title: Make a scatter plot of layer features + - id: napari-matplotlib.features_histogram + python_name: napari_matplotlib:FeaturesHistogramWidget + title: Plot feature histograms + - id: napari-matplotlib.slice python_name: napari_matplotlib:SliceWidget title: Plot a 1D slice @@ -28,5 +32,8 @@ contributions: - command: napari-matplotlib.features_scatter display_name: FeaturesScatter + - command: napari-matplotlib.features_histogram + display_name: FeaturesHistogram + - command: napari-matplotlib.slice display_name: 1D slice diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index db86c7f3..a4148bd2 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -5,6 +5,7 @@ from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget from .base import SingleAxesWidget +from .features import FEATURES_LAYER_TYPES from .util import Interval __all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"] @@ -94,13 +95,7 @@ class FeaturesScatterWidget(ScatterBaseWidget): n_layers_input = Interval(1, 1) # All layers that have a .features attributes - input_layer_types = ( - napari.layers.Labels, - napari.layers.Points, - napari.layers.Shapes, - napari.layers.Tracks, - napari.layers.Vectors, - ) + input_layer_types = FEATURES_LAYER_TYPES def __init__( self, diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram.png new file mode 100644 index 00000000..1892af44 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_feature_histogram.png differ diff --git a/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png b/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png new file mode 100644 index 00000000..b7bb19e0 Binary files /dev/null and b/src/napari_matplotlib/tests/baseline/test_feature_histogram2.png differ diff --git a/src/napari_matplotlib/tests/test_histogram.py b/src/napari_matplotlib/tests/test_histogram.py index 4d170014..006c042f 100644 --- a/src/napari_matplotlib/tests/test_histogram.py +++ b/src/napari_matplotlib/tests/test_histogram.py @@ -1,8 +1,13 @@ from copy import deepcopy +import numpy as np import pytest -from napari_matplotlib import HistogramWidget +from napari_matplotlib import FeaturesHistogramWidget, HistogramWidget +from napari_matplotlib.tests.helpers import ( + assert_figures_equal, + assert_figures_not_equal, +) @pytest.mark.mpl_image_compare @@ -28,3 +33,90 @@ def test_histogram_3D(make_napari_viewer, brain_data): # Need to return a copy, as original figure is too eagerley garbage # collected by the widget return deepcopy(fig) + + +def test_feature_histogram(make_napari_viewer): + n_points = 1000 + random_points = np.random.random((n_points, 3)) * 10 + feature1 = np.random.random(n_points) + feature2 = np.random.normal(size=n_points) + + viewer = make_napari_viewer() + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points1", + ) + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points2", + ) + + widget = FeaturesHistogramWidget(viewer) + viewer.window.add_dock_widget(widget) + + # Check whether changing the selected key changes the plot + widget._set_axis_keys("feature1") + fig1 = deepcopy(widget.figure) + + widget._set_axis_keys("feature2") + assert_figures_not_equal(widget.figure, fig1) + + # check whether selecting a different layer produces the same plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[1]) + assert_figures_equal(widget.figure, fig1) + + +@pytest.mark.mpl_image_compare +def test_feature_histogram2(make_napari_viewer): + import numpy as np + + np.random.seed(0) + n_points = 1000 + random_points = np.random.random((n_points, 3)) * 10 + feature1 = np.random.random(n_points) + feature2 = np.random.normal(size=n_points) + + viewer = make_napari_viewer() + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points1", + ) + viewer.add_points( + random_points, + properties={"feature1": feature1, "feature2": feature2}, + name="points2", + ) + + widget = FeaturesHistogramWidget(viewer) + viewer.window.add_dock_widget(widget) + widget._set_axis_keys("feature1") + + fig = FeaturesHistogramWidget(viewer).figure + return deepcopy(fig) + + +def test_change_layer(make_napari_viewer, brain_data, astronaut_data): + viewer = make_napari_viewer() + widget = HistogramWidget(viewer) + + viewer.add_image(brain_data[0], **brain_data[1]) + viewer.add_image(astronaut_data[0], **astronaut_data[1]) + + # Select first layer + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[0]) + fig1 = deepcopy(widget.figure) + + # Re-selecting first layer should produce identical plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[0]) + assert_figures_equal(widget.figure, fig1) + + # Plotting the second layer should produce a different plot + viewer.layers.selection.clear() + viewer.layers.selection.add(viewer.layers[1]) + assert_figures_not_equal(widget.figure, fig1)