diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c6824b62..f5616c4a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,6 +25,7 @@ repos: rev: v1.2.0 hooks: - id: mypy + additional_dependencies: [numpy, matplotlib] - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py index 6c0d3d90..cb1e8498 100644 --- a/src/napari_matplotlib/scatter.py +++ b/src/napari_matplotlib/scatter.py @@ -1,8 +1,8 @@ -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import matplotlib.colors as mcolor import napari -import numpy as np +import numpy.typing as npt from magicgui import magicgui from magicgui.widgets import ComboBox @@ -65,7 +65,7 @@ def draw(self) -> None: self.axes.set_xlabel(x_axis_name) self.axes.set_ylabel(y_axis_name) - def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: """Get the plot data. This must be implemented on the subclass. @@ -93,7 +93,7 @@ class ScatterWidget(ScatterBaseWidget): n_layers_input = Interval(2, 2) input_layer_types = (napari.layers.Image,) - def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: """ Get the plot data. @@ -191,7 +191,7 @@ def _get_valid_axis_keys( else: return self.layers[0].features.keys() - def _get_data(self) -> Tuple[List[np.ndarray], str, str]: + def _get_data(self) -> Tuple[List[npt.NDArray[Any]], str, str]: """ Get the plot data. diff --git a/src/napari_matplotlib/slice.py b/src/napari_matplotlib/slice.py index bd8d219a..4e22bad8 100644 --- a/src/napari_matplotlib/slice.py +++ b/src/napari_matplotlib/slice.py @@ -1,7 +1,8 @@ -from typing import Dict, Tuple +from typing import Any, Dict, Tuple import napari import numpy as np +import numpy.typing as npt from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QSpinBox from .base import NapariMPLWidget @@ -87,7 +88,7 @@ def update_slice_selectors(self) -> None: for i, dim in enumerate(_dims_sel): self.slice_selectors[dim].setRange(0, self.layer.data.shape[i]) - def get_xy(self) -> Tuple[np.ndarray, np.ndarray]: + def get_xy(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any]]: """ Get data for plotting. """ diff --git a/src/napari_matplotlib/tests/test_scatter.py b/src/napari_matplotlib/tests/test_scatter.py index b2349014..fe07655d 100644 --- a/src/napari_matplotlib/tests/test_scatter.py +++ b/src/napari_matplotlib/tests/test_scatter.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Tuple import numpy as np +import numpy.typing as npt from napari_matplotlib import FeaturesScatterWidget, ScatterWidget @@ -22,7 +23,7 @@ def test_features_scatter_widget(make_napari_viewer): def make_labels_layer_with_features() -> ( - Tuple[np.ndarray, Dict[str, Tuple[Any]]] + Tuple[npt.NDArray[np.uint16], Dict[str, Any]] ): label_image = np.zeros((100, 100), dtype=np.uint16) for label_value, start_index in enumerate([10, 30, 50], start=1):