Skip to content

Commit d05c0b9

Browse files
authored
PERF: DataFrame.where for EA dtype mask (#51574)
1 parent 89f6a12 commit d05c0b9

File tree

3 files changed

+37
-26
lines changed

3 files changed

+37
-26
lines changed

asv_bench/benchmarks/frame_methods.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,21 @@ def time_memory_usage_object_dtype(self):
770770
self.df2.memory_usage(deep=True)
771771

772772

773+
class Where:
    """Benchmark ``DataFrame.where`` with a boolean mask, per dtype and mode."""

    # One list per benchmark argument, in the order ``setup`` and
    # ``time_where`` receive them after ``self``: (inplace, dtype).
    params = (
        [True, False],
        ["float64", "Float64", "float64[pyarrow]"],
    )
    # Must name every entry of ``params``; the original listed only "dtype",
    # leaving the ``inplace`` axis unnamed.
    param_names = ["inplace", "dtype"]

    def setup(self, inplace, dtype):
        """Build a 100_000 x 10 random frame of ``dtype`` and its ``< 0`` mask."""
        self.df = DataFrame(np.random.randn(100_000, 10), dtype=dtype)
        self.mask = self.df < 0

    def time_where(self, inplace, dtype):
        """Time ``where`` replacing entries where the mask is False with 0.0."""
        self.df.where(self.mask, other=0.0, inplace=inplace)
786+
787+
773788
class FindValidIndex:
774789
param_names = ["dtype"]
775790
params = [

doc/source/whatsnew/v2.1.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ Deprecations
100100

101101
Performance improvements
102102
~~~~~~~~~~~~~~~~~~~~~~~~
103+
- Performance improvement in :meth:`DataFrame.where` when ``cond`` is backed by an extension dtype (:issue:`51574`)
103104
- Performance improvement in :meth:`~arrays.ArrowExtensionArray.isna` when array has zero nulls or is all nulls (:issue:`51630`)
104105
- Performance improvement in :meth:`DataFrame.first_valid_index` and :meth:`DataFrame.last_valid_index` for extension array dtypes (:issue:`51549`)
105106
- Performance improvement in :meth:`DataFrame.clip` and :meth:`Series.clip` (:issue:`51472`)

pandas/core/generic.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7989,33 +7989,22 @@ def _clip_with_scalar(self, lower, upper, inplace: bool_t = False):
79897989
):
79907990
raise ValueError("Cannot use an NA value as a clip threshold")
79917991

7992-
mgr = self._mgr
7992+
result = self
7993+
mask = self.isna()
79937994

7994-
if inplace:
7995-
# cond (for putmask) identifies values to be updated.
7996-
# exclude boundary as values at the boundary should be no-ops.
7997-
if upper is not None:
7998-
cond = self > upper
7999-
mgr = mgr.putmask(mask=cond, new=upper, align=False)
8000-
if lower is not None:
8001-
cond = self < lower
8002-
mgr = mgr.putmask(mask=cond, new=lower, align=False)
8003-
else:
8004-
# cond (for where) identifies values to be left as-is.
8005-
# include boundary as values at the boundary should be no-ops.
8006-
mask = isna(self)
8007-
if upper is not None:
8008-
cond = mask | (self <= upper)
8009-
mgr = mgr.where(other=upper, cond=cond, align=False)
8010-
if lower is not None:
8011-
cond = mask | (self >= lower)
8012-
mgr = mgr.where(other=lower, cond=cond, align=False)
8013-
8014-
result = self._constructor(mgr)
8015-
if inplace:
8016-
return self._update_inplace(result)
8017-
else:
8018-
return result.__finalize__(self)
7995+
if lower is not None:
7996+
cond = mask | (self >= lower)
7997+
result = result.where(
7998+
cond, lower, inplace=inplace
7999+
) # type: ignore[assignment]
8000+
if upper is not None:
8001+
cond = mask | (self <= upper)
8002+
result = self if inplace else result
8003+
result = result.where(
8004+
cond, upper, inplace=inplace
8005+
) # type: ignore[assignment]
8006+
8007+
return result
80198008

80208009
@final
80218010
def _clip_with_one_bound(self, threshold, method, axis, inplace):
@@ -9629,6 +9618,12 @@ def _where(
96299618
for _dt in cond.dtypes:
96309619
if not is_bool_dtype(_dt):
96319620
raise ValueError(msg.format(dtype=_dt))
9621+
if cond._mgr.any_extension_types:
9622+
# GH51574: avoid object ndarray conversion later on
9623+
cond = cond._constructor(
9624+
cond.to_numpy(dtype=bool, na_value=fill_value),
9625+
**cond._construct_axes_dict(),
9626+
)
96329627
else:
96339628
# GH#21947 we have an empty DataFrame/Series, could be object-dtype
96349629
cond = cond.astype(bool)

0 commit comments

Comments
 (0)