Skip to content

Commit

Permalink
Deal with the mad aggregation being removed in Pandas 2 (#602)
Browse files Browse the repository at this point in the history
  • Loading branch information
bartbroere authored Nov 6, 2023
1 parent 5b3a83e commit 5e5f36b
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions tests/dataframe/test_groupby_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@

from tests.common import TestData

PANDAS_MAJOR_VERSION = int(pd.__version__.split(".")[0])


# The mean absolute deviation (mad) aggregation was removed from pandas
# in major version 2:
# https://github.com/pandas-dev/pandas/issues/11787
# To check that eland's implementation of it still works, we need to
# implement the reference calculation here ourselves.
def mad(x):
    """Return the mean absolute deviation of ``x`` around its own mean.

    ``x`` must support subtraction, ``abs()`` and a ``.mean()`` method.
    In this file it is handed to pandas' groupby ``aggregate`` as a
    stand-in for the built-in "mad" aggregation.
    """
    deviations = x - x.mean()
    return abs(deviations).mean()


class TestGroupbyDataFrame(TestData):
funcs = ["max", "min", "mean", "sum"]
Expand Down Expand Up @@ -71,7 +82,7 @@ def test_groupby_aggregate_single_aggs(self, pd_agg):
@pytest.mark.parametrize("dropna", [True, False])
@pytest.mark.parametrize("pd_agg", ["max", "min", "mean", "sum", "median"])
def test_groupby_aggs_numeric_only_true(self, pd_agg, dropna):
# Pandas has numeric_only applicable for the above aggs with groupby only.
# Pandas has numeric_only applicable for the above aggs with groupby only.

pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)
Expand All @@ -95,7 +106,14 @@ def test_groupby_aggs_mad_var_std(self, pd_agg, dropna):
pd_flights = self.pd_flights().filter(self.filter_data)
ed_flights = self.ed_flights().filter(self.filter_data)

pd_groupby = getattr(pd_flights.groupby("Cancelled", dropna=dropna), pd_agg)()
# The mad aggregation has been removed in Pandas 2, so we need to use
# our own implementation if we run the tests with Pandas 2 or higher
if PANDAS_MAJOR_VERSION >= 2 and pd_agg == "mad":
pd_groupby = pd_flights.groupby("Cancelled", dropna=dropna).aggregate(mad)
else:
pd_groupby = getattr(
pd_flights.groupby("Cancelled", dropna=dropna), pd_agg
)()
ed_groupby = getattr(ed_flights.groupby("Cancelled", dropna=dropna), pd_agg)(
numeric_only=True
)
Expand Down Expand Up @@ -211,14 +229,20 @@ def test_groupby_dataframe_mad(self):
pd_flights = self.pd_flights().filter(self.filter_data + ["DestCountry"])
ed_flights = self.ed_flights().filter(self.filter_data + ["DestCountry"])

pd_mad = pd_flights.groupby("DestCountry").mad()
if PANDAS_MAJOR_VERSION < 2:
pd_mad = pd_flights.groupby("DestCountry").mad()
else:
pd_mad = pd_flights.groupby("DestCountry").aggregate(mad)
ed_mad = ed_flights.groupby("DestCountry").mad()

assert_index_equal(pd_mad.columns, ed_mad.columns)
assert_index_equal(pd_mad.index, ed_mad.index)
assert_series_equal(pd_mad.dtypes, ed_mad.dtypes)

pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"])
if PANDAS_MAJOR_VERSION < 2:
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", "mad"])
else:
pd_min_mad = pd_flights.groupby("DestCountry").aggregate(["min", mad])
ed_min_mad = ed_flights.groupby("DestCountry").aggregate(["min", "mad"])

assert_index_equal(pd_min_mad.columns, ed_min_mad.columns)
Expand Down

0 comments on commit 5e5f36b

Please sign in to comment.