diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 6642f5855f4fe..c8c784132c084 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -80,6 +80,7 @@ Other enhancements - Added half-year offset classes :class:`HalfYearBegin`, :class:`HalfYearEnd`, :class:`BHalfYearBegin` and :class:`BHalfYearEnd` (:issue:`60928`) - Added support to read from Apache Iceberg tables with the new :func:`read_iceberg` function (:issue:`61383`) - Errors occurring during SQL I/O will now throw a generic :class:`.DatabaseError` instead of the raw Exception type from the underlying driver manager library (:issue:`60748`) +- Implemented :meth:`MultiIndex.searchsorted` (:issue:`14833`) - Implemented :meth:`Series.str.isascii` and :meth:`Series.str.isascii` (:issue:`59091`) - Improved deprecation message for offset aliases (:issue:`60820`) - Multiplying two :class:`DateOffset` objects will now raise a ``TypeError`` instead of a ``RecursionError`` (:issue:`59442`) @@ -87,7 +88,6 @@ Other enhancements - Support passing a :class:`Iterable[Hashable]` input to :meth:`DataFrame.drop_duplicates` (:issue:`59237`) - Support reading Stata 102-format (Stata 1) dta files (:issue:`58978`) - Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`) - .. --------------------------------------------------------------------------- .. 
@overload
def searchsorted(  # type: ignore[overload-overlap]
    self,
    value: ScalarLike_co,
    side: Literal["left", "right"] = ...,
    sorter: NumpySorter = ...,
) -> np.intp: ...


@overload
def searchsorted(
    self,
    value: npt.ArrayLike | ExtensionArray,
    side: Literal["left", "right"] = ...,
    sorter: NumpySorter = ...,
) -> npt.NDArray[np.intp]: ...


def searchsorted(
    self,
    value: NumpyValueArrayLike | ExtensionArray,
    side: Literal["left", "right"] = "left",
    sorter: npt.NDArray[np.intp] | None = None,
) -> npt.NDArray[np.intp] | np.intp:
    """
    Find the indices where elements should be inserted to maintain order.

    Keys are first looked up with :meth:`get_indexer`. A key that is present
    takes its own position (``side='left'``) or the position just after it
    (``side='right'``). Keys that are absent -- and every key when `sorter`
    is given -- are located with :func:`numpy.searchsorted` over a structured
    array built from the index, with one field per level.

    Parameters
    ----------
    value : tuple or list of tuple
        A single key (tuple) or a list of keys to search for in the
        MultiIndex.
    side : {'left', 'right'}, default 'left'
        If 'left', the index of the first suitable location found is given.
        If 'right', return the last such index. The two only differ for keys
        that are already present in the MultiIndex.
    sorter : 1-D array-like, optional
        Optional array of integer indices that sort the MultiIndex. When
        given, returned positions refer to the sorted order.

    Returns
    -------
    np.intp or npt.NDArray[np.intp]
        A scalar insertion point when `value` is a single tuple, otherwise
        an array with one insertion point per key in `value`.

    Raises
    ------
    TypeError
        If `value` is neither a tuple nor a list.
    ValueError
        If `value` is empty, or `side` is not 'left' or 'right'.

    See Also
    --------
    Index.searchsorted : Search for insertion point in a 1-D index.

    Notes
    -----
    The exact-match shortcut uses a key's position as its insertion point,
    which is only meaningful for a sorted index without duplicate keys.

    Examples
    --------
    >>> mi = pd.MultiIndex.from_arrays([["a", "b", "c"], ["x", "y", "z"]])
    >>> mi.searchsorted(("b", "y"))
    np.int64(1)
    >>> mi.searchsorted([("a", "x"), ("b", "y")])
    array([0, 1])
    """
    # Validate the type before testing emptiness: ``not value`` on an ndarray
    # raises an unrelated "truth value is ambiguous" error, and on None/0/""
    # it would misreport a type problem as an empty value.
    if not isinstance(value, (tuple, list)):
        raise TypeError("value must be a tuple or list")

    if not value:
        raise ValueError("searchsorted requires a non-empty value")

    if side not in ["left", "right"]:
        raise ValueError("side must be either 'left' or 'right'")

    # A tuple is one key and yields a scalar; a list always yields an array,
    # even when it holds a single key.
    is_single_key = isinstance(value, tuple)
    keys = [value] if is_single_key else value

    if sorter is None:
        indexer = np.asarray(self.get_indexer(keys), dtype=np.intp)
    else:
        # get_indexer positions refer to the stored order, while
        # np.searchsorted with a sorter answers in sorted order, so the
        # exact-match shortcut is not valid here.
        indexer = np.full(len(keys), -1, dtype=np.intp)

    missing = indexer == -1
    result = np.empty(len(keys), dtype=np.intp)
    result[~missing] = indexer[~missing] + (1 if side == "right" else 0)

    if missing.any():
        # Build the structured dtype and the structured copy of the index
        # once, rather than once per missing key.
        dtype = np.dtype(
            [
                (f"level_{num}", np.asarray(level).dtype)
                for num, level in enumerate(self.levels)
            ]
        )
        haystack = np.asarray(self.values, dtype=dtype)
        needles = np.array(
            [key for key, is_missing in zip(keys, missing) if is_missing],
            dtype=dtype,
        )
        result[missing] = np.searchsorted(
            haystack, needles, side=side, sorter=sorter
        )

    if is_single_key:
        return result[0]
    return result
@pytest.fixture
def mi():
    """Sorted two-level MultiIndex with unique keys, shared by the tests below."""
    keys = [("a", 0), ("a", 1), ("b", 0), ("b", 1), ("c", 0)]
    return MultiIndex.from_tuples(keys)


@pytest.mark.parametrize(
    "value, side, expected",
    [
        pytest.param(("b", 0), "left", 2, id="b-0-left"),
        pytest.param(("b", 0), "right", 3, id="b-0-right"),
        pytest.param(("a", 0), "left", 0, id="a-0-left"),
        pytest.param(("a", -1), "left", 0, id="a--1-left"),
        pytest.param(("c", 1), "left", 5, id="c-1-left-beyond"),
    ],
)
def test_searchsorted_single(mi, value, side, expected):
    # GH14833
    # One tuple key: covers present keys on both sides, a key sorting before
    # the first entry and a key sorting after the last entry.
    assert np.all(mi.searchsorted(value, side=side) == expected)


@pytest.mark.parametrize(
    "values, side, expected",
    [
        pytest.param(
            [("a", 1), ("b", 0), ("c", 0)],
            "left",
            np.array([1, 2, 4], dtype=np.intp),
            id="list-left",
        ),
        pytest.param(
            [("a", 1), ("b", 0), ("c", 0)],
            "right",
            np.array([2, 3, 5], dtype=np.intp),
            id="list-right",
        ),
    ],
)
def test_searchsorted_list(mi, values, side, expected):
    # A list of keys must come back as an intp array, one position per key.
    tm.assert_numpy_array_equal(mi.searchsorted(values, side=side), expected)


@pytest.mark.parametrize(
    "value, side, error_type, match",
    [
        pytest.param(
            ("a", 1),
            "middle",
            ValueError,
            "side must be either 'left' or 'right'",
            id="invalid-side",
        ),
        pytest.param(
            "a",
            "left",
            TypeError,
            "value must be a tuple or list",
            id="invalid-value-type",
        ),
    ],
)
def test_searchsorted_invalid(mi, value, side, error_type, match):
    # Bad arguments are rejected with a specific exception type and message.
    with pytest.raises(error_type, match=match):
        mi.searchsorted(value, side=side)