Skip to content

BUG: GH10355 groupby std() doesn't sqrt grouping cols #11507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.17.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Bug Fixes
- Bug in merging ``datetime64[ns, tz]`` dtypes (:issue:`11405`)
- Bug in ``HDFStore.select`` when comparing with a numpy scalar in a where clause (:issue:`11283`)
- Bug in using ``DataFrame.ix`` with a multi-index indexer(:issue:`11372`)
- Bug in ``groupby.std()`` when using ``DataFrame.groupby`` with ``as_index=False``, where the grouping columns were also square-rooted (:issue:`10355`)


- Bug in tz-conversions with an ambiguous time and ``.dt`` accessors (:issue:`11295`)
Expand Down
20 changes: 19 additions & 1 deletion pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,25 @@ def std(self, ddof=1):
For multiple groupings, the result index will be a MultiIndex
"""
# todo, implement at cython level?
return np.sqrt(self.var(ddof=ddof))
if self.as_index:
return np.sqrt(self.var(ddof=ddof))
else:
df = self.var(ddof=ddof)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need all of this logic, just use self._selected_obj

# if we are selecting columns we mustn't root them
try:
if type(self.keys) in (list, tuple) and all(k in df.columns for
k in self.keys):
to_root = [c for c in df.columns if c not in self.keys]
elif self.keys in df.columns:
to_root = [c for c in df.columns if c != self.keys]
else:
return np.sqrt(df)
except TypeError:
return np.sqrt(df)

df[to_root] = np.sqrt(df[to_root])
return df

def var(self, ddof=1):
"""
Expand Down
54 changes: 54 additions & 0 deletions pandas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -5546,6 +5546,60 @@ def test_nunique_with_object(self):
expected = pd.Series([1] * 5, name='name', index=index)
tm.assert_series_equal(result, expected)

def test_std_with_as_index_false_and_column_grouping(self):
# GH 10355
df = pd.DataFrame({
'a': [1, 1, 1, 2, 2, 2, 3, 3, 3],
'b': [1, 2, 3, 4, 5, 6, 7, 8, 9],
})
sd = df.groupby('a', as_index=False).std()
sd2 = df.groupby(['a'], as_index=False).std()
sd3 = df.groupby(('a'), as_index=False).std()

expected = pd.DataFrame({
'a': [1, 2, 3],
'b': [1.0, 1.0, 1.0],
})
tm.assert_frame_equal(expected, sd)
tm.assert_frame_equal(expected, sd2)
tm.assert_frame_equal(expected, sd3)

def test_std_with_as_index_false_and_index_grouping(self):
    """std() with as_index=False when the group keys are not columns.

    The same three-rows-per-group labelling is supplied as a function of
    the index, a list, a Series and a dict; every variant is expected to
    produce the per-group standard deviation of both 'a' and 'b'.
    """
    frame = pd.DataFrame({
        'a': [1, 1, 1, 2, 2, 2, 3, 4, 5],
        'b': [1, 2, 3, 4, 5, 6, 7, 8, 9],
    })
    labels = pd.Series([1, 1, 1, 2, 2, 2, 3, 3, 3])
    groupers = [
        lambda x: 1 + int(x / 3),
        [1, 1, 1, 2, 2, 2, 3, 3, 3],
        labels,
        dict(labels),
    ]
    results = [frame.groupby(grouper, as_index=False).std()
               for grouper in groupers]

    expected = pd.DataFrame({
        'a': [0.0, 0.0, 1.0],
        'b': [1.0, 1.0, 1.0],
    })
    for result in results:
        tm.assert_frame_equal(expected, result)

def test_std_with_as_index_false_and_ddof_zero(self):
df = pd.DataFrame({
1: [1, 1, 1, 2, 2, 2, 3, 3, 3],
2: [1, 2, 3, 1, 5, 6, 7, 8, 10],
})
sd = df.groupby(1, as_index=False).std(ddof=0)

expected = pd.DataFrame({
1: [1, 2, 3],
2: [
np.std([1, 2, 3], ddof=0),
np.std([1, 5, 6], ddof=0),
np.std([7, 8, 10], ddof=0)],
})
tm.assert_frame_equal(expected, sd)


def assert_fp_equal(a, b):
    """Assert that ``a`` and ``b`` differ element-wise by less than 1e-12.

    The comparison uses the absolute difference; an ``AssertionError`` is
    raised as soon as any element is off by 1e-12 or more.
    """
    deviation = np.abs(a - b)
    within_tolerance = deviation < 1e-12
    assert within_tolerance.all()
Expand Down