Skip to content

Commit adf2b99

Browse files
fix na_value handling for categorical case + update tests for expected categorical behaviour
1 parent 2dfd50b commit adf2b99

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

pandas/core/arrays/categorical.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2678,8 +2678,12 @@ def _str_map(
26782678
codes = self.codes
26792679
if categories.dtype == "string":
26802680
result = categories.array._str_map(f, na_value, dtype)
2681-
if categories.dtype.na_value is np.nan:
2682-
# NaN propagates as False
2681+
if (
2682+
categories.dtype.na_value is np.nan
2683+
and is_bool_dtype(dtype)
2684+
and (na_value is lib.no_default or isna(na_value))
2685+
):
2686+
# NaN propagates as False for functions with boolean return type
26832687
na_value = False
26842688
else:
26852689
from pandas.core.arrays import NumpyExtensionArray

pandas/tests/strings/test_api.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def test_api_per_method(
122122
any_allowed_skipna_inferred_dtype,
123123
any_string_method,
124124
request,
125+
using_infer_string,
125126
):
126127
# this test does not check correctness of the different methods,
127128
# just that the methods work on the specified (inferred) dtypes,
@@ -160,6 +161,10 @@ def test_api_per_method(
160161
t = box(values, dtype=dtype) # explicit dtype to avoid casting
161162
method = getattr(t.str, method_name)
162163

164+
if using_infer_string and dtype == "category":
165+
string_allowed = method_name not in ["decode"]
166+
else:
167+
string_allowed = True
163168
bytes_allowed = method_name in ["decode", "get", "len", "slice"]
164169
# as of v0.23.4, all methods except 'cat' are very lenient with the
165170
# allowed data types, just returning NaN for entries that error.
@@ -168,7 +173,8 @@ def test_api_per_method(
168173
mixed_allowed = method_name not in ["cat"]
169174

170175
allowed_types = (
171-
["string", "unicode", "empty"]
176+
["empty"]
177+
+ ["string", "unicode"] * string_allowed
172178
+ ["bytes"] * bytes_allowed
173179
+ ["mixed", "mixed-integer"] * mixed_allowed
174180
)
@@ -182,6 +188,7 @@ def test_api_per_method(
182188
msg = (
183189
f"Cannot use .str.{method_name} with values of "
184190
f"inferred dtype {inferred_dtype!r}."
191+
"|a bytes-like object is required, not 'str'"
185192
)
186193
with pytest.raises(TypeError, match=msg):
187194
method(*args, **kwargs)

pandas/tests/strings/test_find_replace.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def test_startswith_endswith_validate_na(any_string_dtype):
320320
@pytest.mark.parametrize("dtype", ["object", "category"])
321321
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
322322
@pytest.mark.parametrize("na", [True, False])
323-
def test_startswith(pat, dtype, null_value, na):
323+
def test_startswith(pat, dtype, null_value, na, using_infer_string):
324324
# add category dtype parametrizations for GH-36241
325325
values = Series(
326326
["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"],
@@ -334,6 +334,8 @@ def test_startswith(pat, dtype, null_value, na):
334334
exp = exp.fillna(null_value)
335335
elif dtype == "object" and null_value is None:
336336
exp[exp.isna()] = None
337+
elif using_infer_string and dtype == "category":
338+
exp = exp.fillna(False).astype(bool)
337339
tm.assert_series_equal(result, exp)
338340

339341
result = values.str.startswith(pat, na=na)
@@ -389,7 +391,7 @@ def test_startswith_string_dtype(any_string_dtype, na):
389391
@pytest.mark.parametrize("dtype", ["object", "category"])
390392
@pytest.mark.parametrize("null_value", [None, np.nan, pd.NA])
391393
@pytest.mark.parametrize("na", [True, False])
392-
def test_endswith(pat, dtype, null_value, na):
394+
def test_endswith(pat, dtype, null_value, na, using_infer_string):
393395
# add category dtype parametrizations for GH-36241
394396
values = Series(
395397
["om", null_value, "foo_nom", "nom", "bar_foo", null_value, "foo"],
@@ -403,6 +405,8 @@ def test_endswith(pat, dtype, null_value, na):
403405
exp = exp.fillna(null_value)
404406
elif dtype == "object" and null_value is None:
405407
exp[exp.isna()] = None
408+
elif using_infer_string and dtype == "category":
409+
exp = exp.fillna(False).astype(bool)
406410
tm.assert_series_equal(result, exp)
407411

408412
result = values.str.endswith(pat, na=na)

0 commit comments

Comments
 (0)