Skip to content

API: change ExtensionArray.astype to always return an ExtensionArray #43699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
isin,
unique,
)
from pandas.core.construction import array as pd_array
from pandas.core.sorting import (
nargminmax,
nargsort,
Expand Down Expand Up @@ -535,21 +536,9 @@ def nbytes(self) -> int:
# Additional Methods
# ------------------------------------------------------------------------

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...

@overload
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
...

@overload
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
...

def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
def astype(self, dtype: AstypeArg, copy: bool = True) -> ExtensionArray:
"""
Cast to a NumPy array or ExtensionArray with 'dtype'.
Cast to an ExtensionArray with 'dtype'.

Parameters
----------
Expand All @@ -562,9 +551,9 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:

Returns
-------
array : np.ndarray or ExtensionArray
An ExtensionArray if dtype is ExtensionDtype,
Otherwise a NumPy ndarray with 'dtype' for its dtype.
ExtensionArray
ExtensionArray subclass, depending on `dtype`. If `dtype` is a
numpy dtype, this returns a PandasArray.
"""

dtype = pandas_dtype(dtype)
Expand All @@ -578,7 +567,7 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
cls = dtype.construct_array_type()
return cls._from_sequence(self, dtype=dtype, copy=copy)

return np.array(self, dtype=dtype, copy=copy)
return pd_array(self, dtype=dtype, copy=copy)

def isna(self) -> np.ndarray | ExtensionArraySupportsAnyAll:
"""
Expand Down Expand Up @@ -757,7 +746,7 @@ def fillna(
if mask.any():
if method is not None:
func = missing.get_fill_func(method)
new_values, _ = func(self.astype(object), limit=limit, mask=mask)
new_values, _ = func(self.to_numpy(object), limit=limit, mask=mask)
new_values = self._from_sequence(new_values, dtype=self.dtype)
else:
# fill with value
Expand Down Expand Up @@ -836,7 +825,7 @@ def unique(self: ExtensionArrayT) -> ExtensionArrayT:
-------
uniques : ExtensionArray
"""
uniques = unique(self.astype(object))
uniques = unique(self.to_numpy(object))
return self._from_sequence(uniques, dtype=self.dtype)

def searchsorted(
Expand Down Expand Up @@ -888,9 +877,9 @@ def searchsorted(
# 1. Values outside the range of the `data_for_sorting` fixture
# 2. Values between the values in the `data_for_sorting` fixture
# 3. Missing values.
arr = self.astype(object)
arr = self.to_numpy(object)
if isinstance(value, ExtensionArray):
value = value.astype(object)
value = value.to_numpy(object)
return arr.searchsorted(value, side=side, sorter=sorter)

def equals(self, other: object) -> bool:
Expand Down Expand Up @@ -965,7 +954,7 @@ def _values_for_factorize(self) -> tuple[np.ndarray, Any]:
The values returned by this method are also used in
:func:`pandas.util.hash_pandas_object`.
"""
return self.astype(object), np.nan
return self.to_numpy(object), np.nan

def factorize(self, na_sentinel: int = -1) -> tuple[np.ndarray, ExtensionArray]:
"""
Expand Down
31 changes: 8 additions & 23 deletions pandas/core/arrays/boolean.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from __future__ import annotations

import numbers
from typing import (
TYPE_CHECKING,
overload,
)
from typing import TYPE_CHECKING
import warnings

import numpy as np
Expand All @@ -14,11 +11,9 @@
missing as libmissing,
)
from pandas._typing import (
ArrayLike,
AstypeArg,
Dtype,
DtypeObj,
npt,
type_t,
)
from pandas.compat.numpy import function as nv
Expand All @@ -44,6 +39,7 @@
BaseMaskedArray,
BaseMaskedDtype,
)
from pandas.core.arrays.numpy_ import PandasArray

if TYPE_CHECKING:
import pyarrow
Expand Down Expand Up @@ -367,22 +363,10 @@ def map_string(s):
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
return coerce_to_array(value)

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...

@overload
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
...

@overload
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
...

def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
def astype(self, dtype: AstypeArg, copy: bool = True) -> ExtensionArray:

"""
Cast to a NumPy array or ExtensionArray with 'dtype'.
Cast to an ExtensionArray with 'dtype'.

Parameters
----------
Expand All @@ -395,8 +379,8 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:

Returns
-------
ndarray or ExtensionArray
NumPy ndarray, BooleanArray or IntegerArray with 'dtype' for its dtype.
ExtensionArray
ExtensionArray subclass, depending on `dtype`.

Raises
------
Expand Down Expand Up @@ -426,7 +410,8 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
if is_float_dtype(dtype):
na_value = np.nan
# coerce
return self.to_numpy(dtype=dtype, na_value=na_value, copy=False)
arr = self.to_numpy(dtype=dtype, na_value=na_value, copy=False)
return PandasArray(arr)

def _values_for_argsort(self) -> np.ndarray:
"""
Expand Down
28 changes: 7 additions & 21 deletions pandas/core/arrays/floating.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from typing import overload
import warnings

import numpy as np
Expand All @@ -10,10 +9,8 @@
missing as libmissing,
)
from pandas._typing import (
ArrayLike,
AstypeArg,
DtypeObj,
npt,
)
from pandas.compat.numpy import function as nv
from pandas.util._decorators import cache_readonly
Expand All @@ -39,6 +36,7 @@
NumericArray,
NumericDtype,
)
from pandas.core.arrays.numpy_ import PandasArray
from pandas.core.ops import invalid_comparison
from pandas.core.tools.numeric import to_numeric

Expand Down Expand Up @@ -278,21 +276,9 @@ def _from_sequence_of_strings(
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
return coerce_to_array(value, dtype=self.dtype)

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...

@overload
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
...

@overload
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
...

def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
def astype(self, dtype: AstypeArg, copy: bool = True) -> ExtensionArray:
"""
Cast to a NumPy array or ExtensionArray with 'dtype'.
Cast to an ExtensionArray with 'dtype'.

Parameters
----------
Expand All @@ -305,9 +291,8 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:

Returns
-------
ndarray or ExtensionArray
NumPy ndarray, or BooleanArray, IntegerArray or FloatingArray with
'dtype' for its dtype.
ExtensionArray
ExtensionArray subclass, depending on `dtype`.

Raises
------
Expand All @@ -334,7 +319,8 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
# error: Argument 2 to "to_numpy" of "BaseMaskedArray" has incompatible
# type "**Dict[str, float]"; expected "bool"
data = self.to_numpy(dtype=dtype, **kwargs) # type: ignore[arg-type]
return astype_nansafe(data, dtype, copy=False)
arr = astype_nansafe(data, dtype, copy=False)
return PandasArray(arr)

def _values_for_argsort(self) -> np.ndarray:
return self._data
Expand Down
27 changes: 7 additions & 20 deletions pandas/core/arrays/integer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from typing import overload
import warnings

import numpy as np
Expand All @@ -11,11 +10,9 @@
missing as libmissing,
)
from pandas._typing import (
ArrayLike,
AstypeArg,
Dtype,
DtypeObj,
npt,
)
from pandas.compat.numpy import function as nv
from pandas.util._decorators import cache_readonly
Expand Down Expand Up @@ -46,6 +43,7 @@
NumericArray,
NumericDtype,
)
from pandas.core.arrays.numpy_ import PandasArray
from pandas.core.ops import invalid_comparison
from pandas.core.tools.numeric import to_numeric

Expand Down Expand Up @@ -345,21 +343,9 @@ def _from_sequence_of_strings(
def _coerce_to_array(self, value) -> tuple[np.ndarray, np.ndarray]:
return coerce_to_array(value, dtype=self.dtype)

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...

@overload
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
...

@overload
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
...

def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
def astype(self, dtype: AstypeArg, copy: bool = True) -> ExtensionArray:
"""
Cast to a NumPy array or ExtensionArray with 'dtype'.
Cast to an ExtensionArray with 'dtype'.

Parameters
----------
Expand All @@ -372,8 +358,8 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:

Returns
-------
ndarray or ExtensionArray
NumPy ndarray, BooleanArray or IntegerArray with 'dtype' for its dtype.
ExtensionArray
ExtensionArray subclass, depending on `dtype`.

Raises
------
Expand All @@ -397,7 +383,8 @@ def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
else:
na_value = lib.no_default

return self.to_numpy(dtype=dtype, na_value=na_value, copy=False)
arr = self.to_numpy(dtype=dtype, na_value=na_value, copy=False)
return PandasArray(arr)

def _values_for_argsort(self) -> np.ndarray:
"""
Expand Down
15 changes: 1 addition & 14 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
missing as libmissing,
)
from pandas._typing import (
ArrayLike,
AstypeArg,
NpDtype,
PositionalIndexer,
Expand Down Expand Up @@ -361,19 +360,7 @@ def to_numpy(
data = self._data.astype(dtype, copy=copy)
return data

@overload
def astype(self, dtype: npt.DTypeLike, copy: bool = ...) -> np.ndarray:
...

@overload
def astype(self, dtype: ExtensionDtype, copy: bool = ...) -> ExtensionArray:
...

@overload
def astype(self, dtype: AstypeArg, copy: bool = ...) -> ArrayLike:
...

def astype(self, dtype: AstypeArg, copy: bool = True) -> ArrayLike:
def astype(self, dtype: AstypeArg, copy: bool = True) -> ExtensionArray:
dtype = pandas_dtype(dtype)

if is_dtype_equal(dtype, self.dtype):
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def astype(self, dtype, copy=True):
arr[mask] = 0
values = arr.astype(dtype)
values[mask] = np.nan
return values
return PandasArray(values)

return super().astype(dtype, copy)

Expand Down
5 changes: 4 additions & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,10 @@ def astype_array(values: ArrayLike, dtype: DtypeObj, copy: bool = False) -> Arra

if not isinstance(values, np.ndarray):
# i.e. ExtensionArray
values = values.astype(dtype, copy=copy)
if isinstance(dtype, ExtensionDtype):
values = values.astype(dtype, copy=copy)
else:
values = values.to_numpy(dtype, copy=False)

else:
values = astype_nansafe(values, dtype, copy=copy)
Expand Down
18 changes: 18 additions & 0 deletions pandas/tests/extension/base/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,31 @@
import pandas.util._test_decorators as td

import pandas as pd
from pandas.core.arrays import ExtensionArray
from pandas.core.internals import ObjectBlock
from pandas.tests.extension.base.base import BaseExtensionTests


class BaseCastingTests(BaseExtensionTests):
"""Casting to and from ExtensionDtypes"""

def test_extension_astype_returns_extension_array(self, all_data):
# Base test: whenever ExtensionArray.astype succeeds for a target that maps
# to a NumPy dtype, the result must still be an ExtensionArray.
# https://github.com/pandas-dev/pandas/issues/24877

# Exercise each target twice: once spelled as a dtype string and once as
# the equivalent np.dtype object.
for dtype_str in ["object", "int64", "float64"]:
for dtype in [dtype_str, np.dtype(dtype_str)]:
try:
# Cast only the first two elements; per the original note this is
# meant to give a case without NAs (depends on the all_data fixture,
# which is not visible here -- TODO confirm).
res = all_data[:2].astype(dtype)
except (TypeError, ValueError):
# Not every cast is possible for every array type, so a TypeError or
# ValueError is tolerated; the result type is only checked when
# astype actually returns a value.
pass
else:
assert isinstance(res, ExtensionArray)

def test_astype_object_series(self, all_data):
ser = pd.Series(all_data, name="A")
result = ser.astype(object)
Expand Down
Loading