Skip to content

Commit 5ed2084

Browse files
authored
Merge pull request #39 from kevinyamauchi/add-features-scatter
Add Features scatter plot
2 parents 451dfac + 86abe7a commit 5ed2084

File tree

5 files changed

+346
-19
lines changed

5 files changed

+346
-19
lines changed

examples/features_scatter.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import napari
2+
import numpy as np
3+
from skimage.measure import regionprops_table
4+
5+
# make a test label image
6+
label_image = np.zeros((100, 100), dtype=np.uint16)
7+
8+
label_image[10:20, 10:20] = 1
9+
label_image[50:70, 50:70] = 2
10+
11+
feature_table_1 = regionprops_table(
12+
label_image, properties=("label", "area", "perimeter")
13+
)
14+
feature_table_1["index"] = feature_table_1["label"]
15+
16+
# make the points data
17+
n_points = 100
18+
points_data = 100 * np.random.random((100, 2))
19+
points_features = {
20+
"feature_0": np.random.random((n_points,)),
21+
"feature_1": np.random.random((n_points,)),
22+
"feature_2": np.random.random((n_points,)),
23+
}
24+
25+
# create the viewer
26+
viewer = napari.Viewer()
27+
viewer.add_labels(label_image, features=feature_table_1)
28+
viewer.add_points(points_data, features=points_features)
29+
30+
# make the widget
31+
viewer.window.add_plugin_dock_widget(
32+
plugin_name="napari-matplotlib", widget_name="FeaturesScatter"
33+
)
34+
35+
if __name__ == "__main__":
36+
napari.run()

src/napari_matplotlib/base.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,14 @@ def setup_callbacks(self) -> None:
8787
# z-step changed in viewer
8888
self.viewer.dims.events.current_step.connect(self._draw)
8989
# Layer selection changed in viewer
90-
self.viewer.layers.selection.events.active.connect(self.update_layers)
90+
self.viewer.layers.selection.events.changed.connect(self.update_layers)
9191

9292
def update_layers(self, event: napari.utils.events.Event) -> None:
9393
"""
9494
Update the layers attribute with currently selected layers and re-draw.
9595
"""
9696
self.layers = list(self.viewer.layers.selection)
97+
self._on_update_layers()
9798
self._draw()
9899

99100
def _draw(self) -> None:
@@ -103,6 +104,7 @@ def _draw(self) -> None:
103104
"""
104105
self.clear()
105106
if self.n_selected_layers != self.n_layers_input:
107+
self.canvas.draw()
106108
return
107109
self.draw()
108110
self.canvas.draw()
@@ -120,6 +122,14 @@ def draw(self) -> None:
120122
121123
This is a no-op, and is intended for derived classes to override.
122124
"""
125+
126+
127+
def _on_update_layers(self) -> None:
128+
"""This function is called when self.layers is updated via self.update_layers()
129+
130+
This is a no-op, and is intended for derived classes to override.
131+
"""
132+
123133
def _replace_toolbar_icons(self):
124134
# Modify toolbar icons and some tooltips
125135
for action in self.toolbar.actions():

src/napari_matplotlib/napari.yaml

+8-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ contributions:
88

99
- id: napari-matplotlib.scatter
1010
python_name: napari_matplotlib:ScatterWidget
11-
title: Make a scatter plot
11+
title: Make a scatter plot of image intensities
12+
13+
- id: napari-matplotlib.features_scatter
14+
python_name: napari_matplotlib:FeaturesScatterWidget
15+
title: Make a scatter plot of layer features
1216

1317
- id: napari-matplotlib.slice
1418
python_name: napari_matplotlib:SliceWidget
@@ -21,5 +25,8 @@ contributions:
2125
- command: napari-matplotlib.scatter
2226
display_name: Scatter
2327

28+
- command: napari-matplotlib.features_scatter
29+
display_name: FeaturesScatter
30+
2431
- command: napari-matplotlib.slice
2532
display_name: 1D slice

src/napari_matplotlib/scatter.py

+201-15
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,225 @@
1+
from typing import List, Tuple, Union
2+
13
import matplotlib.colors as mcolor
24
import napari
5+
import numpy as np
6+
from magicgui import magicgui
37

48
from .base import NapariMPLWidget
59

6-
__all__ = ["ScatterWidget"]
10+
__all__ = ["ScatterWidget", "FeaturesScatterWidget"]
711

812

9-
class ScatterWidget(NapariMPLWidget):
10-
"""
11-
Widget to display scatter plot of two similarly shaped layers.
13+
class ScatterBaseWidget(NapariMPLWidget):
14+
# opacity value for the markers
15+
_marker_alpha = 0.5
1216

13-
If there are more than 500 data points, a 2D histogram is displayed instead
14-
of a scatter plot, to avoid too many scatter points.
15-
"""
17+
# flag set to True if histogram should be used
18+
# for plotting large points
19+
_histogram_for_large_data = True
1620

17-
n_layers_input = 2
21+
# if the number of points is greater than this value,
22+
# the scatter is plotted as a 2dhist
23+
_threshold_to_switch_to_histogram = 500
1824

19-
def __init__(self, napari_viewer: napari.viewer.Viewer):
25+
def __init__(
26+
self,
27+
napari_viewer: napari.viewer.Viewer,
28+
):
2029
super().__init__(napari_viewer)
30+
2131
self.axes = self.canvas.figure.subplots()
2232
self.update_layers(None)
2333

34+
def clear(self) -> None:
35+
self.axes.clear()
36+
2437
def draw(self) -> None:
2538
"""
2639
Clear the axes and scatter the currently selected layers.
2740
"""
28-
data = [layer.data[self.current_z] for layer in self.layers]
29-
if data[0].size < 500:
30-
self.axes.scatter(data[0], data[1], alpha=0.5)
31-
else:
41+
data, x_axis_name, y_axis_name = self._get_data()
42+
43+
if len(data) == 0:
44+
# don't plot if there isn't data
45+
return
46+
47+
if self._histogram_for_large_data and (
48+
data[0].size > self._threshold_to_switch_to_histogram
49+
):
3250
self.axes.hist2d(
3351
data[0].ravel(),
3452
data[1].ravel(),
3553
bins=100,
3654
norm=mcolor.LogNorm(),
3755
)
38-
self.axes.set_xlabel(self.layers[0].name)
39-
self.axes.set_ylabel(self.layers[1].name)
56+
else:
57+
self.axes.scatter(data[0], data[1], alpha=self._marker_alpha)
58+
59+
self.axes.set_xlabel(x_axis_name)
60+
self.axes.set_ylabel(y_axis_name)
61+
62+
def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
63+
"""Get the plot data.
64+
65+
This must be implemented on the subclass.
66+
67+
Returns
68+
-------
69+
data : np.ndarray
70+
The list containing the scatter plot data.
71+
x_axis_name : str
72+
The label to display on the x axis
73+
y_axis_name: str
74+
The label to display on the y axis
75+
"""
76+
raise NotImplementedError
77+
78+
79+
class ScatterWidget(ScatterBaseWidget):
80+
"""
81+
Widget to display scatter plot of two similarly shaped image layers.
82+
83+
If there are more than 500 data points, a 2D histogram is displayed instead
84+
of a scatter plot, to avoid too many scatter points.
85+
"""
86+
87+
n_layers_input = 2
88+
89+
def __init__(
90+
self,
91+
napari_viewer: napari.viewer.Viewer,
92+
):
93+
super().__init__(
94+
napari_viewer,
95+
)
96+
97+
def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
98+
"""Get the plot data.
99+
100+
Returns
101+
-------
102+
data : List[np.ndarray]
103+
List contains the in view slice of X and Y axis images.
104+
x_axis_name : str
105+
The title to display on the x axis
106+
y_axis_name: str
107+
The title to display on the y axis
108+
"""
109+
data = [layer.data[self.current_z] for layer in self.layers]
110+
x_axis_name = self.layers[0].name
111+
y_axis_name = self.layers[1].name
112+
113+
return data, x_axis_name, y_axis_name
114+
115+
116+
class FeaturesScatterWidget(ScatterBaseWidget):
117+
n_layers_input = 1
118+
119+
def __init__(
120+
self,
121+
napari_viewer: napari.viewer.Viewer,
122+
key_selection_gui: bool = True,
123+
):
124+
self._key_selection_widget = None
125+
super().__init__(
126+
napari_viewer,
127+
)
128+
129+
if key_selection_gui is True:
130+
self._key_selection_widget = magicgui(
131+
self._set_axis_keys,
132+
x_axis_key={"choices": self._get_valid_axis_keys},
133+
y_axis_key={"choices": self._get_valid_axis_keys},
134+
call_button="plot",
135+
)
136+
self.layout().addWidget(self._key_selection_widget.native)
137+
138+
@property
139+
def x_axis_key(self) -> Union[None, str]:
140+
"""Key to access x axis data from the FeaturesTable"""
141+
return self._x_axis_key
142+
143+
@x_axis_key.setter
144+
def x_axis_key(self, key: Union[None, str]):
145+
self._x_axis_key = key
146+
self._draw()
147+
148+
@property
149+
def y_axis_key(self) -> Union[None, str]:
150+
"""Key to access y axis data from the FeaturesTable"""
151+
return self._y_axis_key
152+
153+
@y_axis_key.setter
154+
def y_axis_key(self, key: Union[None, str]):
155+
self._y_axis_key = key
156+
self._draw()
157+
158+
def _set_axis_keys(self, x_axis_key: str, y_axis_key: str):
159+
"""Set both axis keys and then redraw the plot"""
160+
self._x_axis_key = x_axis_key
161+
self._y_axis_key = y_axis_key
162+
self._draw()
163+
164+
def _get_valid_axis_keys(self, combo_widget=None) -> List[str]:
165+
"""Get the valid axis keys from the layer FeatureTable.
166+
167+
Returns
168+
-------
169+
axis_keys : List[str]
170+
The valid axis keys in the FeatureTable. If the table is empty
171+
or there isn't a table, returns an empty list.
172+
"""
173+
if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")):
174+
return []
175+
else:
176+
return self.layers[0].features.keys()
177+
178+
def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
179+
"""Get the plot data.
180+
181+
Returns
182+
-------
183+
data : List[np.ndarray]
184+
List contains X and Y columns from the FeatureTable. Returns
185+
an empty array if nothing to plot.
186+
x_axis_name : str
187+
The title to display on the x axis. Returns
188+
an empty string if nothing to plot.
189+
y_axis_name: str
190+
The title to display on the y axis. Returns
191+
an empty string if nothing to plot.
192+
"""
193+
if not hasattr(self.layers[0], "features"):
194+
# if the selected layer doesn't have a featuretable,
195+
# skip draw
196+
return np.array([]), "", ""
197+
198+
feature_table = self.layers[0].features
199+
200+
if (
201+
(len(feature_table) == 0)
202+
or (self.x_axis_key is None)
203+
or (self.y_axis_key is None)
204+
):
205+
return np.array([]), "", ""
206+
207+
data_x = feature_table[self.x_axis_key]
208+
data_y = feature_table[self.y_axis_key]
209+
data = [data_x, data_y]
210+
211+
x_axis_name = self.x_axis_key.replace("_", " ")
212+
y_axis_name = self.y_axis_key.replace("_", " ")
213+
214+
return data, x_axis_name, y_axis_name
215+
216+
def _on_update_layers(self) -> None:
217+
"""This is called when the layer selection changes
218+
by self.update_layers().
219+
"""
220+
if self._key_selection_widget is not None:
221+
self._key_selection_widget.reset_choices()
222+
223+
# reset the axis keys
224+
self._x_axis_key = None
225+
self._y_axis_key = None

0 commit comments

Comments
 (0)