From 5a58dd9a0e45ed10febd50a836fb9e84333b8ca1 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Fri, 6 May 2022 16:22:04 +0200 Subject: [PATCH 01/10] initial feature scatter --- src/napari_matplotlib/base.py | 10 +- src/napari_matplotlib/scatter.py | 179 ++++++++++++++++++++++++++++--- 2 files changed, 173 insertions(+), 16 deletions(-) diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index 6bfbd093..f828b8dc 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -79,13 +79,14 @@ def setup_callbacks(self) -> None: # z-step changed in viewer self.viewer.dims.events.current_step.connect(self._draw) # Layer selection changed in viewer - self.viewer.layers.selection.events.active.connect(self.update_layers) + self.viewer.layers.selection.events.changed.connect(self.update_layers) def update_layers(self, event: napari.utils.events.Event) -> None: """ Update the currently selected layers and re-draw. """ self.layers = list(self.viewer.layers.selection) + self._on_update_layers() self._draw() def _draw(self) -> None: @@ -95,6 +96,7 @@ def _draw(self) -> None: """ self.clear() if self.n_selected_layers != self.n_layers_input: + self.canvas.draw() return self.draw() self.canvas.draw() @@ -112,3 +114,9 @@ def draw(self) -> None: This is a no-op, and is intended for derived classes to override. """ + + def _on_update_layers(self) -> None: + """This function is called when self.layers is updated via self.update_layer() + + This is a no-op, and is intended for derived classes to override. + """ diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index c3b12742..75f8bff6 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,39 +1,188 @@ +from typing import List, Tuple, Union + import matplotlib.colors as mcolor import napari +import numpy as np +from magicgui import magicgui from .base import NapariMPLWidget __all__ = ["ScatterWidget"] -class ScatterWidget(NapariMPLWidget): - """ - Widget to display scatter plot of two similarly shaped layers. +class ScatterBaseWidget(NapariMPLWidget): + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + marker_alpha: float = 0.5, + histogram_for_large_data: bool = True, + ): + super().__init__(napari_viewer) - If there are more than 500 data points, a 2D histogram is displayed instead - of a scatter plot, to avoid too many scatter points. - """ + # flag set to True if histogram should be used + # for plotting large points + self.histogram_for_large_data = histogram_for_large_data - n_layers_input = 2 + # set plotting visualization attributes + self._marker_alpha = 0.5 - def __init__(self, napari_viewer: napari.viewer.Viewer): - super().__init__(napari_viewer) self.axes = self.canvas.figure.subplots() self.update_layers(None) + @property + def marker_alpha(self) -> float: + """Alpha (opacity) for the scatter markers""" + return self._marker_alpha + + @marker_alpha.setter + def marker_alpha(self, alpha: float): + self._marker_alpha = alpha + self._draw() + + def clear(self) -> None: + self.axes.clear() + def draw(self) -> None: """ Clear the axes and scatter the currently selected layers. """ - data = [layer.data[self.current_z] for layer in self.layers] - if data[0].size < 500: - self.axes.scatter(data[0], data[1], alpha=0.5) - else: + data, x_axis_name, y_axis_name = self._get_data() + + if len(data) == 0: + # don't plot if there isn't data + return + + if self.histogram_for_large_data and (data[0].size > 500): self.axes.hist2d( data[0].ravel(), data[1].ravel(), bins=100, norm=mcolor.LogNorm(), ) - self.axes.set_xlabel(self.layers[0].name) - self.axes.set_ylabel(self.layers[1].name) + else: + self.axes.scatter(data[0], data[1], alpha=self.marker_alpha) + + self.axes.set_xlabel(x_axis_name) + self.axes.set_ylabel(y_axis_name) + + def _get_data(self) -> Tuple[np.ndarray, str, str]: + raise NotImplementedError + + +class ScatterWidget(ScatterBaseWidget): + """ + Widget to display scatter plot of two similarly shaped layers. + + If there are more than 500 data points, a 2D histogram is displayed instead + of a scatter plot, to avoid too many scatter points. + """ + + n_layers_input = 2 + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + marker_alpha: float = 0.5, + histogram_for_large_data: bool = True, + ): + super().__init__( + napari_viewer, + marker_alpha=marker_alpha, + histogram_for_large_data=histogram_for_large_data, + ) + + def _get_data(self) -> Tuple[np.ndarray, str, str]: + data = [layer.data[self.current_z] for layer in self.layers] + x_axis_name = self.layers[0].name + y_axis_name = self.layers[1].name + + return data, x_axis_name, y_axis_name + + +class FeaturesScatterWidget(ScatterBaseWidget): + n_layers_input = 1 + + def __init__( + self, + napari_viewer: napari.viewer.Viewer, + marker_alpha: float = 0.5, + histogram_for_large_data: bool = True, + key_selection_gui: bool = True, + ): + self._key_selection_widget = None + super().__init__( + napari_viewer, + marker_alpha=marker_alpha, + histogram_for_large_data=histogram_for_large_data, + ) + + if key_selection_gui is True: + self._key_selection_widget = magicgui( + self._set_axis_keys, + x_axis_key={"choices": self._get_valid_axis_keys}, + y_axis_key={"choices": self._get_valid_axis_keys}, + call_button="plot", + ) + self.layout().addWidget(self._key_selection_widget.native) + + @property + def x_axis_key(self) -> Union[None, str]: + """Key to access x axis data from the FeaturesTable""" + return self._x_axis_key + + @x_axis_key.setter + def x_axis_key(self, key: Union[None, str]): + self._x_axis_key = key + self._draw() + + @property + def y_axis_key(self) -> Union[None, str]: + """Key to access y axis data from the FeaturesTable""" + return self._y_axis_key + + @y_axis_key.setter + def y_axis_key(self, key: Union[None, str]): + self._y_axis_key = key + self._draw() + + def _set_axis_keys(self, x_axis_key: str, y_axis_key: str): + """Set both axis keys and then redraw the plot""" + self._x_axis_key = x_axis_key + self._y_axis_key = y_axis_key + self._draw() + + def _get_valid_axis_keys(self, combo_widget=None) -> List[str]: + if len(self.layers) == 0: + return [] + else: + return self.layers[0].features.keys() + + def _get_data(self) -> Tuple[np.ndarray, str, str]: + feature_table = self.layers[0].features + + if ( + (len(feature_table) == 0) + or (self.x_axis_key is None) + or (self.y_axis_key is None) + ): + return np.array([]), "", "" + + data_x = feature_table[self.x_axis_key] + data_y = feature_table[self.y_axis_key] + data = np.stack((data_x, data_y)) + + x_axis_name = self.x_axis_key.replace("_", " ") + y_axis_name = self.y_axis_key.replace("_", " ") + + return data, x_axis_name, y_axis_name + + def _on_update_layers(self) -> None: + """This is called when the layer selection changes + by self.update_layers(). + """ + if self._key_selection_widget is not None: + self._key_selection_widget.reset_choices() + + # reset the axis keys + self._x_axis_key = None + self._y_axis_key = None From ef1029b1c15273b6decb26ff44b437e5ae0ff32e Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Fri, 6 May 2022 16:51:24 +0200 Subject: [PATCH 02/10] update docstrings --- src/napari_matplotlib/scatter.py | 61 +++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 75f8bff6..7b774960 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -65,7 +65,20 @@ def draw(self) -> None: self.axes.set_xlabel(x_axis_name) self.axes.set_ylabel(y_axis_name) - def _get_data(self) -> Tuple[np.ndarray, str, str]: + def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + """Get the plot data. + + This must be implemented on the subclass. + + Returns + ------- + data : np.ndarray + The list containing the scatter plot data. + x_axis_name : str + The title to display on the x axis + y_axis_name: str + The title to display on the y axis + """ raise NotImplementedError @@ -91,7 +104,18 @@ def __init__( histogram_for_large_data=histogram_for_large_data, ) - def _get_data(self) -> Tuple[np.ndarray, str, str]: + def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + """Get the plot data. + + Returns + ------- + data : List[np.ndarray] + List contains the in view slice of X and Y axis images. + x_axis_name : str + The title to display on the x axis + y_axis_name: str + The title to display on the y axis + """ data = [layer.data[self.current_z] for layer in self.layers] x_axis_name = self.layers[0].name y_axis_name = self.layers[1].name @@ -152,12 +176,39 @@ def _set_axis_keys(self, x_axis_key: str, y_axis_key: str): self._draw() def _get_valid_axis_keys(self, combo_widget=None) -> List[str]: - if len(self.layers) == 0: + """Get the valid axis keys from the layer FeatureTable. + + Returns + ------- + axis_keys : List[str] + The valid axis keys in the FeatureTable. If the table is empty + or there isn't a table, returns an empty list. + """ + if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): return [] else: return self.layers[0].features.keys() - def _get_data(self) -> Tuple[np.ndarray, str, str]: + def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + """Get the plot data. + + Returns + ------- + data : List[np.ndarray] + List contains X and Y columns from the FeatureTable. Returns + an empty array if nothing to plot. + x_axis_name : str + The title to display on the x axis. Returns + an empty string if nothing to plot. + y_axis_name: str + The title to display on the y axis. Returns + an empty string if nothing to plot. + """ + if not hasattr(self.layers[0], "features"): + # if the selected layer doesn't have a featuretable, + # skip draw + return np.array([]), "", "" + feature_table = self.layers[0].features if ( @@ -169,7 +220,7 @@ def _get_data(self) -> Tuple[np.ndarray, str, str]: data_x = feature_table[self.x_axis_key] data_y = feature_table[self.y_axis_key] - data = np.stack((data_x, data_y)) + data = [data_x, data_y] x_axis_name = self.x_axis_key.replace("_", " ") y_axis_name = self.y_axis_key.replace("_", " ") From 91928e10ed4f535c34123e61f9ab19a2e2b38224 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Fri, 6 May 2022 16:51:47 +0200 Subject: [PATCH 03/10] expose scatter plots --- src/napari_matplotlib/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_matplotlib/__init__.py b/src/napari_matplotlib/__init__.py index 7e8ccf69..5b112cda 100644 --- a/src/napari_matplotlib/__init__.py +++ b/src/napari_matplotlib/__init__.py @@ -5,4 +5,4 @@ from .histogram import * # NoQA -from .scatter import * # NoQA +from .scatter import FeaturesScatterWidget, ScatterWidget # NoQA From 22fd3ee6308894bf8d42212e031a0ae3ca7a222d Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Fri, 6 May 2022 17:28:00 +0200 Subject: [PATCH 04/10] add tests --- src/napari_matplotlib/tests/test_scatter.py | 92 ++++++++++++++++++++- 1 file changed, 90 insertions(+), 2 deletions(-) diff --git a/src/napari_matplotlib/tests/test_scatter.py b/src/napari_matplotlib/tests/test_scatter.py index 75a6fda6..8103968e 100644 --- a/src/napari_matplotlib/tests/test_scatter.py +++ b/src/napari_matplotlib/tests/test_scatter.py @@ -1,11 +1,99 @@ import numpy as np -from napari_matplotlib import ScatterWidget +from napari_matplotlib import FeaturesScatterWidget, ScatterWidget def test_scatter(make_napari_viewer): - # Smoke test adding a histogram widget + # Smoke test adding a scatter widget viewer = make_napari_viewer() viewer.add_image(np.random.random((100, 100))) viewer.add_image(np.random.random((100, 100))) ScatterWidget(viewer) + + +def test_features_scatter_widget(make_napari_viewer): + # Smoke test adding a features scatter widget + viewer = make_napari_viewer() + viewer.add_image(np.random.random((100, 100))) + viewer.add_labels(np.random.randint(0, 5, (100, 100))) + FeaturesScatterWidget(viewer) + + +def make_labels_layer_with_features(): + label_image = np.zeros((100, 100), dtype=np.uint16) + for label_value, start_index in enumerate([10, 30, 50], start=1): + end_index = start_index + 10 + label_image[start_index:end_index, start_index:end_index] = label_value + feature_table = { + "index": [1, 2, 3], + "feature_0": np.random.random((3,)), + "feature_1": np.random.random((3,)), + "feature_2": np.random.random((3,)), + } + return label_image, feature_table + + +def test_features_scatter_get_data(make_napari_viewer): + """test the get data method""" + # make the label image + label_image, feature_table = make_labels_layer_with_features() + + viewer = make_napari_viewer() + labels_layer = viewer.add_labels(label_image, features=feature_table) + scatter_widget = FeaturesScatterWidget(viewer) + + # select the labels layer + viewer.layers.selection = [labels_layer] + + x_column = "feature_0" + scatter_widget.x_axis_key = x_column + y_column = "feature_2" + scatter_widget.y_axis_key = y_column + + data, x_axis_name, y_axis_name = scatter_widget._get_data() + np.testing.assert_allclose( + data, np.stack((feature_table[x_column], feature_table[y_column])) + ) + assert x_axis_name == x_column.replace("_", " ") + assert y_axis_name == y_column.replace("_", " ") + + +def test_get_valid_axis_keys(make_napari_viewer): + """test the values returned from + FeaturesScatterWidget._get_valid_keys() when there + are valid keys. + """ + # make the label image + label_image, feature_table = make_labels_layer_with_features() + + viewer = make_napari_viewer() + labels_layer = viewer.add_labels(label_image, features=feature_table) + scatter_widget = FeaturesScatterWidget(viewer) + + viewer.layers.selection = [labels_layer] + valid_keys = scatter_widget._get_valid_axis_keys() + assert set(valid_keys) == set(feature_table.keys()) + + +def test_get_valid_axis_keys_no_valid_keys(make_napari_viewer): + """test the values returned from + FeaturesScatterWidget._get_valid_keys() when there + are not valid keys. + """ + # make the label image + label_image, _ = make_labels_layer_with_features() + + viewer = make_napari_viewer() + labels_layer = viewer.add_labels(label_image) + image_layer = viewer.add_image(np.random.random((100, 100))) + scatter_widget = FeaturesScatterWidget(viewer) + + # no features in a label image + viewer.layers.selection = [labels_layer] + valid_keys = scatter_widget._get_valid_axis_keys() + assert set(valid_keys) == set() + + # image layer doesn't have features + viewer.layers.selection = [image_layer] + valid_keys = scatter_widget._get_valid_axis_keys() + assert set(valid_keys) == set() From e79eeafa3fae055bfe726b703d26bf761f72608a Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Fri, 6 May 2022 17:35:15 +0200 Subject: [PATCH 05/10] add example --- examples/features_scatter.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 examples/features_scatter.py diff --git a/examples/features_scatter.py b/examples/features_scatter.py new file mode 100644 index 00000000..0894d0a4 --- /dev/null +++ b/examples/features_scatter.py @@ -0,0 +1,35 @@ +import napari +import numpy as np +from skimage.measure import regionprops_table + +from napari_matplotlib.scatter import FeaturesScatterWidget + +# make a test label image +label_image = np.zeros((100, 100), dtype=np.uint16) + +label_image[10:20, 10:20] = 1 +label_image[50:70, 50:70] = 2 + +feature_table_1 = regionprops_table(label_image, properties=("label",)) +feature_table_1["index"] = feature_table_1["label"] + +# make the points data +n_points = 100 +points_data = 100 * np.random.random((100, 2)) +points_features = { + "feature_0": np.random.random((n_points,)), + "feature_1": np.random.random((n_points,)), + "feature_2": np.random.random((n_points,)), +} + +# create the viewer +viewer = napari.Viewer() +viewer.add_labels(label_image, features=feature_table_1) +viewer.add_points(points_data, features=points_features) + +# make the widget +features_widget = FeaturesScatterWidget(viewer, histogram_for_large_data=False) +viewer.window.add_dock_widget(features_widget) + +if __name__ == "__main__": + napari.run() From 9d3a0420c0d3673b84cf925b6453d51672871dd1 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Sat, 7 May 2022 05:00:34 -0700 Subject: [PATCH 06/10] fix imports --- src/napari_matplotlib/__init__.py | 2 +- src/napari_matplotlib/scatter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/napari_matplotlib/__init__.py b/src/napari_matplotlib/__init__.py index 5b112cda..7e8ccf69 100644 --- a/src/napari_matplotlib/__init__.py +++ b/src/napari_matplotlib/__init__.py @@ -5,4 +5,4 @@ from .histogram import * # NoQA -from .scatter import FeaturesScatterWidget, ScatterWidget # NoQA +from .scatter import * # NoQA diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 7b774960..5a504218 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -7,7 +7,7 @@ from .base import NapariMPLWidget -__all__ = ["ScatterWidget"] +__all__ = ["ScatterWidget", "FeaturesScatterWidget"] class ScatterBaseWidget(NapariMPLWidget): From fc03d7834d09dfe114f3af9facb580dc58076e8e Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Sat, 7 May 2022 05:40:43 -0700 Subject: [PATCH 07/10] update example --- examples/features_scatter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/features_scatter.py b/examples/features_scatter.py index 0894d0a4..28c8d006 100644 --- a/examples/features_scatter.py +++ b/examples/features_scatter.py @@ -10,7 +10,9 @@ label_image[10:20, 10:20] = 1 label_image[50:70, 50:70] = 2 -feature_table_1 = regionprops_table(label_image, properties=("label",)) +feature_table_1 = regionprops_table( + label_image, properties=("label", "area", "perimeter") +) feature_table_1["index"] = feature_table_1["label"] # make the points data From 0340c54d41d1f3cb8307e32af15d1308714d5a48 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Sun, 8 May 2022 05:42:33 -0700 Subject: [PATCH 08/10] Apply suggestions from code review Co-authored-by: David Stansby --- src/napari_matplotlib/base.py | 2 +- src/napari_matplotlib/scatter.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py index f828b8dc..b49e368c 100644 --- a/src/napari_matplotlib/base.py +++ b/src/napari_matplotlib/base.py @@ -116,7 +116,7 @@ def draw(self) -> None: """ def _on_update_layers(self) -> None: - """This function is called when self.layers is updated via self.update_layer() + """This function is called when self.layers is updated via self.update_layers() This is a no-op, and is intended for derived classes to override. """ diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 5a504218..7035326a 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -75,16 +75,16 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]: data : np.ndarray The list containing the scatter plot data. x_axis_name : str - The title to display on the x axis + The label to display on the x axis y_axis_name: str - The title to display on the y axis + The label to display on the y axis """ raise NotImplementedError class ScatterWidget(ScatterBaseWidget): """ - Widget to display scatter plot of two similarly shaped layers. + Widget to display scatter plot of two similarly shaped image layers. If there are more than 500 data points, a 2D histogram is displayed instead of a scatter plot, to avoid too many scatter points. From 9f4cbb95dba366c766a75df262b49cd98bd9e16e Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Sun, 8 May 2022 05:55:35 -0700 Subject: [PATCH 09/10] expose FeaturesScatterWidget as plugin --- examples/features_scatter.py | 7 +++--- src/napari_matplotlib/napari.yaml | 9 ++++++- src/napari_matplotlib/scatter.py | 42 ++++++++++--------------------- 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/examples/features_scatter.py b/examples/features_scatter.py index 28c8d006..ac8580d7 100644 --- a/examples/features_scatter.py +++ b/examples/features_scatter.py @@ -2,8 +2,6 @@ import numpy as np from skimage.measure import regionprops_table -from napari_matplotlib.scatter import FeaturesScatterWidget - # make a test label image label_image = np.zeros((100, 100), dtype=np.uint16) @@ -30,8 +28,9 @@ viewer.add_points(points_data, features=points_features) # make the widget -features_widget = FeaturesScatterWidget(viewer, histogram_for_large_data=False) -viewer.window.add_dock_widget(features_widget) +viewer.window.add_plugin_dock_widget( + plugin_name="napari-matplotlib", widget_name="FeaturesScatter" +) if __name__ == "__main__": napari.run() diff --git a/src/napari_matplotlib/napari.yaml b/src/napari_matplotlib/napari.yaml index 3ff66090..3beba8a7 100644 --- a/src/napari_matplotlib/napari.yaml +++ b/src/napari_matplotlib/napari.yaml @@ -8,7 +8,11 @@ contributions: - id: napari-matplotlib.scatter python_name: napari_matplotlib:ScatterWidget - title: Make a scatter plot + title: Make a scatter plot of image intensities + + - id: napari-matplotlib.features_scatter + python_name: napari_matplotlib:FeaturesScatterWidget + title: Make a scatter plot of layer features widgets: - command: napari-matplotlib.histogram @@ -16,3 +20,6 @@ contributions: - command: napari-matplotlib.scatter display_name: Scatter + + - command: napari-matplotlib.features_scatter + display_name: FeaturesScatter diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 7035326a..3de2e88e 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -11,34 +11,26 @@ class ScatterBaseWidget(NapariMPLWidget): + # opacity value for the markers + _marker_alpha = 0.5 + + # flag set to True if histogram should be used + # for plotting large points + _histogram_for_large_data = True + + # if the number of points is greater than this value, + # the scatter is plotted as a 2dhist + _threshold_to_switch_to_histogram = 500 + def __init__( self, napari_viewer: napari.viewer.Viewer, - marker_alpha: float = 0.5, - histogram_for_large_data: bool = True, ): super().__init__(napari_viewer) - # flag set to True if histogram should be used - # for plotting large points - self.histogram_for_large_data = histogram_for_large_data - - # set plotting visualization attributes - self._marker_alpha = 0.5 - self.axes = self.canvas.figure.subplots() self.update_layers(None) - @property - def marker_alpha(self) -> float: - """Alpha (opacity) for the scatter markers""" - return self._marker_alpha - - @marker_alpha.setter - def marker_alpha(self, alpha: float): - self._marker_alpha = alpha - self._draw() - def clear(self) -> None: self.axes.clear() @@ -52,7 +44,7 @@ def draw(self) -> None: # don't plot if there isn't data return - if self.histogram_for_large_data and (data[0].size > 500): + if self._histogram_for_large_data and (data[0].size > 500): self.axes.hist2d( data[0].ravel(), data[1].ravel(), @@ -60,7 +52,7 @@ def draw(self) -> None: norm=mcolor.LogNorm(), ) else: - self.axes.scatter(data[0], data[1], alpha=self.marker_alpha) + self.axes.scatter(data[0], data[1], alpha=self._marker_alpha) self.axes.set_xlabel(x_axis_name) self.axes.set_ylabel(y_axis_name) @@ -95,13 +87,9 @@ class ScatterWidget(ScatterBaseWidget): def __init__( self, napari_viewer: napari.viewer.Viewer, - marker_alpha: float = 0.5, - histogram_for_large_data: bool = True, ): super().__init__( napari_viewer, - marker_alpha=marker_alpha, - histogram_for_large_data=histogram_for_large_data, ) def _get_data(self) -> Tuple[List[np.ndarray], str, str]: @@ -129,15 +117,11 @@ class FeaturesScatterWidget(ScatterBaseWidget): def __init__( self, napari_viewer: napari.viewer.Viewer, - marker_alpha: float = 0.5, - histogram_for_large_data: bool = True, key_selection_gui: bool = True, ): self._key_selection_widget = None super().__init__( napari_viewer, - marker_alpha=marker_alpha, - histogram_for_large_data=histogram_for_large_data, ) if key_selection_gui is True: From e1941b0a35c4dc8cedf476d98d3a225371ca8865 Mon Sep 17 00:00:00 2001 From: Kevin Yamauchi Date: Sun, 8 May 2022 06:07:40 -0700 Subject: [PATCH 10/10] add threshold for histogram switch --- src/napari_matplotlib/scatter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 3de2e88e..324e9126 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -44,7 +44,9 @@ def draw(self) -> None: # don't plot if there isn't data return - if self._histogram_for_large_data and (data[0].size > 500): + if self._histogram_for_large_data and ( + data[0].size > self._threshold_to_switch_to_histogram + ): self.axes.hist2d( data[0].ravel(), data[1].ravel(),