-
Notifications
You must be signed in to change notification settings - Fork 22
Add Features scatter plot #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
5a58dd9
ef1029b
91928e1
22fd3ee
e79eeaf
9d3a042
fc03d78
0340c54
9f4cbb9
f2a28d6
e1941b0
86abe7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import napari | ||
import numpy as np | ||
from skimage.measure import regionprops_table | ||
|
||
from napari_matplotlib.scatter import FeaturesScatterWidget | ||
|
||
# make a test label image | ||
label_image = np.zeros((100, 100), dtype=np.uint16) | ||
|
||
label_image[10:20, 10:20] = 1 | ||
label_image[50:70, 50:70] = 2 | ||
|
||
feature_table_1 = regionprops_table( | ||
label_image, properties=("label", "area", "perimeter") | ||
) | ||
feature_table_1["index"] = feature_table_1["label"] | ||
|
||
# make the points data | ||
n_points = 100 | ||
points_data = 100 * np.random.random((100, 2)) | ||
points_features = { | ||
"feature_0": np.random.random((n_points,)), | ||
"feature_1": np.random.random((n_points,)), | ||
"feature_2": np.random.random((n_points,)), | ||
} | ||
|
||
# create the viewer | ||
viewer = napari.Viewer() | ||
viewer.add_labels(label_image, features=feature_table_1) | ||
viewer.add_points(points_data, features=points_features) | ||
|
||
# make the widget | ||
features_widget = FeaturesScatterWidget(viewer, histogram_for_large_data=False) | ||
viewer.window.add_dock_widget(features_widget) | ||
|
||
if __name__ == "__main__": | ||
napari.run() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,239 @@ | ||
from typing import List, Tuple, Union | ||
|
||
import matplotlib.colors as mcolor | ||
import napari | ||
import numpy as np | ||
from magicgui import magicgui | ||
|
||
from .base import NapariMPLWidget | ||
|
||
__all__ = ["ScatterWidget"] | ||
__all__ = ["ScatterWidget", "FeaturesScatterWidget"] | ||
|
||
|
||
class ScatterWidget(NapariMPLWidget): | ||
""" | ||
Widget to display scatter plot of two similarly shaped layers. | ||
class ScatterBaseWidget(NapariMPLWidget): | ||
def __init__( | ||
self, | ||
napari_viewer: napari.viewer.Viewer, | ||
marker_alpha: float = 0.5, | ||
histogram_for_large_data: bool = True, | ||
kevinyamauchi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): | ||
super().__init__(napari_viewer) | ||
|
||
If there are more than 500 data points, a 2D histogram is displayed instead | ||
of a scatter plot, to avoid too many scatter points. | ||
""" | ||
# flag set to True if histogram should be used | ||
# for plotting large points | ||
self.histogram_for_large_data = histogram_for_large_data | ||
|
||
n_layers_input = 2 | ||
# set plotting visualization attributes | ||
self._marker_alpha = 0.5 | ||
|
||
def __init__(self, napari_viewer: napari.viewer.Viewer): | ||
super().__init__(napari_viewer) | ||
self.axes = self.canvas.figure.subplots() | ||
self.update_layers(None) | ||
|
||
@property | ||
def marker_alpha(self) -> float: | ||
"""Alpha (opacity) for the scatter markers""" | ||
return self._marker_alpha | ||
|
||
@marker_alpha.setter | ||
def marker_alpha(self, alpha: float): | ||
self._marker_alpha = alpha | ||
self._draw() | ||
|
||
def clear(self) -> None: | ||
self.axes.clear() | ||
|
||
def draw(self) -> None: | ||
""" | ||
Clear the axes and scatter the currently selected layers. | ||
""" | ||
data = [layer.data[self.current_z] for layer in self.layers] | ||
if data[0].size < 500: | ||
self.axes.scatter(data[0], data[1], alpha=0.5) | ||
else: | ||
data, x_axis_name, y_axis_name = self._get_data() | ||
|
||
if len(data) == 0: | ||
# don't plot if there isn't data | ||
return | ||
|
||
if self.histogram_for_large_data and (data[0].size > 500): | ||
self.axes.hist2d( | ||
data[0].ravel(), | ||
data[1].ravel(), | ||
bins=100, | ||
norm=mcolor.LogNorm(), | ||
) | ||
self.axes.set_xlabel(self.layers[0].name) | ||
self.axes.set_ylabel(self.layers[1].name) | ||
else: | ||
self.axes.scatter(data[0], data[1], alpha=self.marker_alpha) | ||
|
||
self.axes.set_xlabel(x_axis_name) | ||
self.axes.set_ylabel(y_axis_name) | ||
|
||
def _get_data(self) -> Tuple[List[np.ndarray], str, str]: | ||
"""Get the plot data. | ||
|
||
This must be implemented on the subclass. | ||
|
||
Returns | ||
------- | ||
data : np.ndarray | ||
The list containing the scatter plot data. | ||
x_axis_name : str | ||
The title to display on the x axis | ||
kevinyamauchi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
y_axis_name: str | ||
The title to display on the y axis | ||
kevinyamauchi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class ScatterWidget(ScatterBaseWidget): | ||
""" | ||
Widget to display scatter plot of two similarly shaped layers. | ||
kevinyamauchi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
If there are more than 500 data points, a 2D histogram is displayed instead | ||
of a scatter plot, to avoid too many scatter points. | ||
""" | ||
|
||
n_layers_input = 2 | ||
|
||
def __init__( | ||
self, | ||
napari_viewer: napari.viewer.Viewer, | ||
marker_alpha: float = 0.5, | ||
histogram_for_large_data: bool = True, | ||
): | ||
super().__init__( | ||
napari_viewer, | ||
marker_alpha=marker_alpha, | ||
histogram_for_large_data=histogram_for_large_data, | ||
) | ||
|
||
def _get_data(self) -> Tuple[List[np.ndarray], str, str]: | ||
"""Get the plot data. | ||
|
||
Returns | ||
------- | ||
data : List[np.ndarray] | ||
List contains the in view slice of X and Y axis images. | ||
x_axis_name : str | ||
The title to display on the x axis | ||
y_axis_name: str | ||
The title to display on the y axis | ||
""" | ||
data = [layer.data[self.current_z] for layer in self.layers] | ||
x_axis_name = self.layers[0].name | ||
y_axis_name = self.layers[1].name | ||
|
||
return data, x_axis_name, y_axis_name | ||
|
||
|
||
class FeaturesScatterWidget(ScatterBaseWidget): | ||
n_layers_input = 1 | ||
|
||
def __init__( | ||
self, | ||
napari_viewer: napari.viewer.Viewer, | ||
marker_alpha: float = 0.5, | ||
histogram_for_large_data: bool = True, | ||
key_selection_gui: bool = True, | ||
): | ||
self._key_selection_widget = None | ||
super().__init__( | ||
napari_viewer, | ||
marker_alpha=marker_alpha, | ||
histogram_for_large_data=histogram_for_large_data, | ||
) | ||
|
||
if key_selection_gui is True: | ||
self._key_selection_widget = magicgui( | ||
self._set_axis_keys, | ||
x_axis_key={"choices": self._get_valid_axis_keys}, | ||
y_axis_key={"choices": self._get_valid_axis_keys}, | ||
call_button="plot", | ||
) | ||
self.layout().addWidget(self._key_selection_widget.native) | ||
Comment on lines
+130
to
+136
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is awesome - I've been trying to get my head around how to use magicgui to do something like this to create part of a GUI for a while! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cool! Let me know if you have questions - happy to chat! |
||
|
||
@property | ||
def x_axis_key(self) -> Union[None, str]: | ||
"""Key to access x axis data from the FeaturesTable""" | ||
return self._x_axis_key | ||
|
||
@x_axis_key.setter | ||
def x_axis_key(self, key: Union[None, str]): | ||
self._x_axis_key = key | ||
self._draw() | ||
|
||
@property | ||
def y_axis_key(self) -> Union[None, str]: | ||
"""Key to access y axis data from the FeaturesTable""" | ||
return self._y_axis_key | ||
|
||
@y_axis_key.setter | ||
def y_axis_key(self, key: Union[None, str]): | ||
self._y_axis_key = key | ||
self._draw() | ||
|
||
def _set_axis_keys(self, x_axis_key: str, y_axis_key: str): | ||
"""Set both axis keys and then redraw the plot""" | ||
self._x_axis_key = x_axis_key | ||
self._y_axis_key = y_axis_key | ||
self._draw() | ||
|
||
def _get_valid_axis_keys(self, combo_widget=None) -> List[str]: | ||
"""Get the valid axis keys from the layer FeatureTable. | ||
|
||
Returns | ||
------- | ||
axis_keys : List[str] | ||
The valid axis keys in the FeatureTable. If the table is empty | ||
or there isn't a table, returns an empty list. | ||
""" | ||
if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")): | ||
return [] | ||
else: | ||
return self.layers[0].features.keys() | ||
|
||
def _get_data(self) -> Tuple[List[np.ndarray], str, str]: | ||
"""Get the plot data. | ||
|
||
Returns | ||
------- | ||
data : List[np.ndarray] | ||
List contains X and Y columns from the FeatureTable. Returns | ||
an empty array if nothing to plot. | ||
x_axis_name : str | ||
The title to display on the x axis. Returns | ||
an empty string if nothing to plot. | ||
y_axis_name: str | ||
The title to display on the y axis. Returns | ||
an empty string if nothing to plot. | ||
""" | ||
if not hasattr(self.layers[0], "features"): | ||
# if the selected layer doesn't have a featuretable, | ||
# skip draw | ||
return np.array([]), "", "" | ||
|
||
feature_table = self.layers[0].features | ||
|
||
if ( | ||
(len(feature_table) == 0) | ||
or (self.x_axis_key is None) | ||
or (self.y_axis_key is None) | ||
): | ||
return np.array([]), "", "" | ||
|
||
data_x = feature_table[self.x_axis_key] | ||
data_y = feature_table[self.y_axis_key] | ||
data = [data_x, data_y] | ||
|
||
x_axis_name = self.x_axis_key.replace("_", " ") | ||
y_axis_name = self.y_axis_key.replace("_", " ") | ||
|
||
return data, x_axis_name, y_axis_name | ||
|
||
def _on_update_layers(self) -> None: | ||
"""This is called when the layer selection changes | ||
by self.update_layers(). | ||
""" | ||
if self._key_selection_widget is not None: | ||
self._key_selection_widget.reset_choices() | ||
|
||
# reset the axis keys | ||
self._x_axis_key = None | ||
self._y_axis_key = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah thanks for catching this!