Support pd.merge on frames that hold extension-array columns (GH-20743).

* pandas/core/dtypes/common.py: _get_dtype returns any ExtensionDtype
  instance unchanged instead of special-casing CategoricalDtype.
* pandas/core/internals.py: when a single block is "concatenated",
  concatenate_join_units only consults ``.base`` for numpy arrays; any
  other array type is always copied when ``copy`` is set.
  get_reindexed_values keeps ``block.values`` for every extension block,
  not only categorical ones.
* pandas/tests/extension: add a base ``test_merge`` (skipped for
  Categorical) and compare column labels in the decimal
  ``assert_frame_equal`` helper.

diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py
index 3a90feb7ccd7d..c45838e6040a9 100644
--- a/pandas/core/dtypes/common.py
+++ b/pandas/core/dtypes/common.py
@@ -1807,7 +1807,7 @@ def _get_dtype(arr_or_dtype):
         return arr_or_dtype
     elif isinstance(arr_or_dtype, type):
         return np.dtype(arr_or_dtype)
-    elif isinstance(arr_or_dtype, CategoricalDtype):
+    elif isinstance(arr_or_dtype, ExtensionDtype):
         return arr_or_dtype
     elif isinstance(arr_or_dtype, DatetimeTZDtype):
         return arr_or_dtype
diff --git a/pandas/core/internals.py b/pandas/core/internals.py
index 37d11296400be..e98899b2f5c1a 100644
--- a/pandas/core/internals.py
+++ b/pandas/core/internals.py
@@ -5541,8 +5541,14 @@ def concatenate_join_units(join_units, concat_axis, copy):
     if len(to_concat) == 1:
         # Only one block, nothing to concatenate.
         concat_values = to_concat[0]
-        if copy and concat_values.base is not None:
-            concat_values = concat_values.copy()
+        if copy:
+            if isinstance(concat_values, np.ndarray):
+                # non-reindexed (=not yet copied) arrays are made into a view
+                # in JoinUnit.get_reindexed_values
+                if concat_values.base is not None:
+                    concat_values = concat_values.copy()
+            else:
+                concat_values = concat_values.copy()
     else:
         concat_values = _concat._concat_compat(to_concat, axis=concat_axis)
 
@@ -5823,7 +5829,7 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
                 # External code requested filling/upcasting, bool values must
                 # be upcasted to object to avoid being upcasted to numeric.
                 values = self.block.astype(np.object_).values
-            elif self.block.is_categorical:
+            elif self.block.is_extension:
                 values = self.block.values
             else:
                 # No dtype upcasting is done here, it will be performed during
diff --git a/pandas/tests/extension/base/reshaping.py b/pandas/tests/extension/base/reshaping.py
index efc22c19a3eef..f50222b82df0f 100644
--- a/pandas/tests/extension/base/reshaping.py
+++ b/pandas/tests/extension/base/reshaping.py
@@ -95,3 +95,24 @@ def test_set_frame_overwrite_object(self, data):
         df = pd.DataFrame({"A": [1] * len(data)}, dtype=object)
         df['A'] = data
         assert df.dtypes['A'] == data.dtype
+
+    def test_merge(self, data, na_value):
+        # GH-20743
+        df1 = pd.DataFrame({'ext': data[:3], 'int1': [1, 2, 3],
+                            'key': [0, 1, 2]})
+        df2 = pd.DataFrame({'int2': [1, 2, 3, 4], 'key': [0, 0, 1, 3]})
+
+        res = pd.merge(df1, df2)
+        exp = pd.DataFrame(
+            {'int1': [1, 1, 2], 'int2': [1, 2, 3], 'key': [0, 0, 1],
+             'ext': data._constructor_from_sequence(
+                 [data[0], data[0], data[1]])})
+        self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']])
+
+        res = pd.merge(df1, df2, how='outer')
+        exp = pd.DataFrame(
+            {'int1': [1, 1, 2, 3, np.nan], 'int2': [1, 2, 3, np.nan, 4],
+             'key': [0, 0, 1, 2, 3],
+             'ext': data._constructor_from_sequence(
+                 [data[0], data[0], data[1], data[2], na_value])})
+        self.assert_frame_equal(res, exp[['ext', 'int1', 'key', 'int2']])
diff --git a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py
index 27c156c15203f..6ebe700f13be0 100644
--- a/pandas/tests/extension/category/test_categorical.py
+++ b/pandas/tests/extension/category/test_categorical.py
@@ -75,6 +75,10 @@ def test_align(self, data, na_value):
     def test_align_frame(self, data, na_value):
         pass
 
+    @pytest.mark.skip(reason="Unobserved categories preserved in concat.")
+    def test_merge(self, data, na_value):
+        pass
+
 
 class TestGetitem(base.BaseGetitemTests):
     @pytest.mark.skip(reason="Backwards compatibility")
diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py
index d509170565e1a..53d74cd6d38cb 100644
--- a/pandas/tests/extension/decimal/test_decimal.py
+++ b/pandas/tests/extension/decimal/test_decimal.py
@@ -72,6 +72,14 @@ def assert_series_equal(self, left, right, *args, **kwargs):
 
     def assert_frame_equal(self, left, right, *args, **kwargs):
         # TODO(EA): select_dtypes
+        tm.assert_index_equal(
+            left.columns, right.columns,
+            exact=kwargs.get('check_column_type', 'equiv'),
+            check_names=kwargs.get('check_names', True),
+            check_exact=kwargs.get('check_exact', False),
+            check_categorical=kwargs.get('check_categorical', True),
+            obj='{obj}.columns'.format(obj=kwargs.get('obj', 'DataFrame')))
+
         decimals = (left.dtypes == 'decimal').index
 
         for col in decimals: