diff --git a/CHANGES.rst b/CHANGES.rst
index 3933466e1..f622af997 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -12,12 +12,15 @@ Ongoing development
 New features
 ------------
-- The :class:`TableReport` now switch it's visual theme between light and dark according to the user preferences.
+- The :class:`TableReport` now switches its visual theme between light and dark according to the user preferences.
   :pr:`1201` by :user:`rouk1 `.
 
 - Adding a new way to control the location of the data directory, using envar
   `SKRUB_DATA_DIRECTORY`. :pr:`1215` by :user:`Thomas S. `
 
+- The :class:`DatetimeEncoder` now supports periodic encoding of the features using circular (sine/cosine) and spline
+  transformers. :pr:`1235` by :user:`Riccardo Cappuzzo`.
+
 Changes
 -------
diff --git a/examples/03_datetime_encoder.py b/examples/03_datetime_encoder.py
index 25ba8401a..53053836f 100644
--- a/examples/03_datetime_encoder.py
+++ b/examples/03_datetime_encoder.py
@@ -32,13 +32,13 @@
 .. |make_column_transformer| replace::
     :class:`~sklearn.compose.make_column_transformer`
 
-.. |HGBR| replace::
-    :class:`~sklearn.ensemble.HistGradientBoostingRegressor`
+.. |RidgeCV| replace::
+    :class:`~sklearn.linear_model.RidgeCV`
 
 .. |ToDatetime| replace::
     :class:`~skrub.ToDatetime`
 """
-
+# %%
 ###############################################################################
 # A problem with relevant datetime features
 # -----------------------------------------
@@ -128,6 +128,19 @@
 # |DatetimeEncoder| is used on the correct column(s).
 pprint(table_vec_weekday.transformers_)
 
+###############################################################################
+# The |DatetimeEncoder| can generate additional periodic features using either
+# B-Splines (|SplineTransformer|) or trigonometric functions. To do so, set the
+# ``periodic_encoding`` parameter to ``circular`` or ``spline``. In this
+# example, we use ``spline``.
+# We can also add the day in the year with the parameter ``add_day_of_year``.
+
+table_vec_periodic = TableVectorizer(
+    datetime=DatetimeEncoder(
+        add_weekday=True, periodic_encoding="spline", add_day_of_year=True
+    )
+).fit(X)
+
 ###############################################################################
 # Prediction with datetime features
 # ---------------------------------
@@ -135,13 +148,14 @@
 # For prediction tasks, we recommend using the |TableVectorizer| inside a
 # pipeline, combined with a model that can use the features extracted by the
 # |DatetimeEncoder|.
-# Here we'll use a |HGBR| as our learner.
+# Here we'll use a :class:`~sklearn.ensemble.HistGradientBoostingRegressor` as our learner.
 #
 from sklearn.ensemble import HistGradientBoostingRegressor
 from sklearn.pipeline import make_pipeline
 
 pipeline = make_pipeline(table_vec, HistGradientBoostingRegressor())
 pipeline_weekday = make_pipeline(table_vec_weekday, HistGradientBoostingRegressor())
+pipeline_periodic = make_pipeline(table_vec_periodic, HistGradientBoostingRegressor())
 
 ###############################################################################
 # Evaluating the model
@@ -155,14 +169,29 @@
 # which ensures that the test set is always in the future.
 from sklearn.model_selection import TimeSeriesSplit, cross_val_score
 
-cross_val_score(
+score_base = cross_val_score(
     pipeline,
     X,
     y,
-    scoring="neg_mean_squared_error",
+    scoring="neg_root_mean_squared_error",
+    cv=TimeSeriesSplit(n_splits=5),
+)
+
+score_periodic = cross_val_score(
+    pipeline_periodic,
+    X,
+    y,
+    scoring="neg_root_mean_squared_error",
     cv=TimeSeriesSplit(n_splits=5),
 )
 
+print(f"Base transformer - Mean RMSE : {-score_base.mean():.2f}")
+print(f"Transformer with periodic features - Mean RMSE : {-score_periodic.mean():.2f}")
+
+###############################################################################
+# As expected, introducing new features improved the RMSE by a noticeable amount.
+
+
 ###############################################################################
 # Plotting the prediction
 # .......................
@@ -184,10 +213,16 @@ pipeline_weekday.fit(X_train, y_train) y_pred_weekday = pipeline_weekday.predict(X_test) +pipeline_periodic.fit(X_train, y_train) +y_pred_periodic = pipeline_periodic.predict(X_test) + +X_plot = pd.to_datetime(X.tail(96)["date"]).values +X_test_plot = pd.to_datetime(X_test.tail(96)["date"]).values + fig, ax = plt.subplots(figsize=(12, 3)) fig.suptitle("Predictions with tree models") ax.plot( - X.tail(96)["date"], + X_plot, y.tail(96).values, "x-", alpha=0.2, @@ -195,18 +230,38 @@ color="black", ) ax.plot( - X_test.tail(96)["date"], + X_test_plot, y_pred[-96:], "x-", - label="DatetimeEncoder() + HGBR prediction", + label="DatetimeEncoder() + HGBDT prediction", ) ax.plot( - X_test.tail(96)["date"], + X_test_plot, y_pred_weekday[-96:], "x-", - label="DatetimeEncoder(add_weekday=True) + HGBR prediction", + label="DatetimeEncoder(add_weekday=True) + HGBDT prediction", +) + +ax.plot( + X_test_plot, + y_pred_periodic[-96:], + "x-", + label='DatetimeEncoder(periodic_encoding="spline") + HGBDT prediction', +) + + +ax.xaxis.set_major_locator(mdates.DayLocator()) +ax.xaxis.set_minor_locator( + mdates.HourLocator( + [0, 6, 12, 18], + ) ) +# Major formatter: format date as "YYYY-MM-DD" +ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d")) +# # Minor formatter: format time as "HH:MM" +ax.xaxis.set_minor_formatter(mdates.DateFormatter("%H:%M")) + ax.tick_params(axis="x", labelsize=7, labelrotation=75) ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d")) _ = ax.legend() @@ -231,10 +286,10 @@ # In this case, we don't use the whole pipeline, because we want to compute the # importance of the features created by the DatetimeEncoder -X_test_transform = pipeline[:-1].transform(X_test) +X_test_transform = pipeline_weekday[:-1].transform(X_test) result = permutation_importance( - pipeline[-1], X_test_transform, y_test, n_repeats=10, random_state=0 + pipeline_weekday[-1], X_test_transform, y_test, n_repeats=10, random_state=0 ) result = 
pd.DataFrame( @@ -256,8 +311,8 @@ plt.show() # %% -# We can see that the hour of the day, the temperature and the humidity -# are the most important features, which seems reasonable. +# We can see that the hour of the day, the temperature, the day of the week +# and the humidity are the most important features, which seems reasonable. # # Conclusion # ---------- @@ -266,3 +321,5 @@ # features from a datetime column. # Also check out the |TableVectorizer|, which automatically recognizes # and transforms datetime columns by default. + +# %% diff --git a/skrub/_dataframe/_common.py b/skrub/_dataframe/_common.py index 8f423128d..88ecc2e6e 100644 --- a/skrub/_dataframe/_common.py +++ b/skrub/_dataframe/_common.py @@ -97,6 +97,7 @@ "unique", "filter", "where", + "where_row", "sample", "head", "slice", @@ -1182,6 +1183,23 @@ def _where_polars(col, mask, other): return col.zip_with(mask, pl.Series(other)) +@dispatch +def where_row(obj, mask, other): + raise NotImplementedError() + + +@where_row.specialize("pandas") +def _where_row_pandas(obj, mask, other): + return obj.apply(pd.Series.where, **{"cond": mask, "other": other}) + + +@where_row.specialize("polars") +def _where_row_polars(obj, mask, other): + return obj.with_columns( + pl.when(pl.Series(mask)).then(pl.all()).otherwise(pl.Series(other)) + ) + + @dispatch def sample(obj, n, seed=None): raise NotImplementedError() diff --git a/skrub/_dataframe/tests/test_common.py b/skrub/_dataframe/tests/test_common.py index a9830ea05..65acd75b6 100644 --- a/skrub/_dataframe/tests/test_common.py +++ b/skrub/_dataframe/tests/test_common.py @@ -778,6 +778,24 @@ def test_where(df_module): ) +def test_where_row(df_module): + df = df_module.make_dataframe({"col1": [1, 2, 3], "col2": [1000, 2000, 3000]}) + out = ns.where_row( + df, + df_module.make_column("", [False, True, False]), # mask + df_module.make_column( + "", [None, None, None] + ), # values to put in on the entire row + ) + right = df_module.make_dataframe( + {"col1": [None, 
2, None], "col2": [None, 2000, None]} + ) + df_module.assert_frame_equal( + ns.pandas_convert_dtypes(out), + ns.pandas_convert_dtypes(right), + ) + + def test_sample(df_module): s = ns.pandas_convert_dtypes(df_module.make_column("", [0, 1, 2])) sample = ns.sample(s, 2) diff --git a/skrub/_datetime_encoder.py b/skrub/_datetime_encoder.py index fea47ded0..22e39637a 100644 --- a/skrub/_datetime_encoder.py +++ b/skrub/_datetime_encoder.py @@ -1,6 +1,8 @@ from datetime import datetime, timezone +import numpy as np import pandas as pd +from sklearn.preprocessing import SplineTransformer from sklearn.utils.validation import check_is_fitted try: @@ -26,6 +28,9 @@ "nanosecond", ] +_DEFAULT_ENCODING_PERIODS = {"year": 366, "month": 30, "weekday": 7, "hour": 24} +_DEFAULT_ENCODING_SPLINES = {"year": 12, "month": 4, "weekday": 7, "hour": 24} + @dispatch def _is_date(col): @@ -58,6 +63,9 @@ def _get_dt_feature_pandas(col, feature): return ((col - epoch) / pd.Timedelta("1s")).astype("float32") if feature == "weekday": return col.dt.day_of_week + 1 + if feature == "day_of_year": + return col.dt.day_of_year + assert feature in _TIME_LEVELS return getattr(col.dt, feature) @@ -66,6 +74,8 @@ def _get_dt_feature_pandas(col, feature): def _get_dt_feature_polars(col, feature): if feature == "total_seconds": return (col.dt.timestamp(time_unit="ms") / 1000).cast(pl.Float32) + if feature == "day_of_year": + return col.dt.ordinal_day() assert feature in _TIME_LEVELS + ["weekday"] return getattr(col.dt, feature)() @@ -105,11 +115,22 @@ class DatetimeEncoder(SingleColumnTransformer): Add the total number of seconds since the Unix epoch (00:00:00 UTC on 1 January 1970). + add_day_of_year : bool, default=False + Add the day of year (ordinal day) as an integer in the range 1 to 365 (or + 366 in the case of leap years). January 1st will be day 1, December 31st + will be day 365 on non-leap years. 
+
+    periodic_encoding : str or None, default=None
+        Add periodic features with different granularities, using either
+        trigonometric (``circular``) or ``spline`` encoding.
+
     Attributes
     ----------
     extracted_features_ : list of strings
         The features that are extracted, a subset of ["year", …, "nanosecond",
-        "weekday", "total_seconds"]
+        "weekday", "total_seconds"]. If ``periodic_encoding`` is set, the periodic
+        features will also be added. Given a feature named ``date``, new features
+        will be named ``date_year_circular_0``, ``date_year_circular_1`` etc.
 
     See Also
     --------
@@ -181,7 +202,7 @@ class DatetimeEncoder(SingleColumnTransformer):
     0   2024.0          4.0         14.0          1.713053e+09
     1   2024.0          5.0         15.0          1.715731e+09
     >>> encoder.extracted_features_
-    ['year', 'month', 'day', 'total_seconds']
+    ['birthday_year', 'birthday_month', 'birthday_day', 'birthday_total_seconds']
 
     (The number of seconds since Epoch can still be extracted but not
     "hour", "minute", etc.)
@@ -254,12 +275,40 @@ class DatetimeEncoder(SingleColumnTransformer):
     Here we can see the input to ``transform`` has been converted back to the
     timezone used during ``fit`` and that we get the same result for "hour".
+
+    The DatetimeEncoder can also create new features based on either trigonometric
+    functions or splines by setting ``periodic_encoding="circular"`` or ``periodic_encoding="spline"``
+    respectively.
+    (https://scikit-learn.org/stable/auto_examples/applications/plot_cyclical_feature_engineering.html).
+
+    >>> encoder = make_pipeline(ToDatetime(), DatetimeEncoder(periodic_encoding="circular"))
+    >>> encoder.fit_transform(login)
+       login_year  login_month  ...  login_hour_circular_0  login_hour_circular_1
+    0      2024.0          5.0  ...           1.224647e-16              -1.000000
+    1         NaN          NaN  ...                    NaN                    NaN
+    2      2024.0          5.0  ...          -2.588190e-01              -0.965926
+
+    Added features can be explored using ``DatetimeEncoder.extracted_features_``:
+    >>> encoder[-1].extracted_features_
+    ['login_year', 'login_month', 'login_day', 'login_hour', 'login_total_seconds',
+    'login_year_circular_0', 'login_year_circular_1', 'login_month_circular_0',
+    'login_month_circular_1', 'login_weekday_circular_0', 'login_weekday_circular_1',
+    'login_hour_circular_0', 'login_hour_circular_1']
     """  # noqa: E501
 
-    def __init__(self, resolution="hour", add_weekday=False, add_total_seconds=True):
+    def __init__(
+        self,
+        resolution="hour",
+        add_weekday=False,
+        add_total_seconds=True,
+        add_day_of_year=False,
+        periodic_encoding=None,
+    ):
         self.resolution = resolution
         self.add_weekday = add_weekday
         self.add_total_seconds = add_total_seconds
+        self.add_day_of_year = add_day_of_year
+        self.periodic_encoding = periodic_encoding
 
     def fit_transform(self, column, y=None):
         """Fit the encoder and transform a column.
@@ -284,16 +333,50 @@ def fit_transform(self, column, y=None):
                 f"Column {sbd.name(column)!r} does not have Date or Datetime dtype."
         )
         if self.resolution is None:
-            self.extracted_features_ = []
+            self._partial_features = []
         else:
             idx_level = _TIME_LEVELS.index(self.resolution)
             if _is_date(column):
                 idx_level = min(idx_level, _TIME_LEVELS.index("day"))
-            self.extracted_features_ = _TIME_LEVELS[: idx_level + 1]
+            self._partial_features = _TIME_LEVELS[: idx_level + 1]
         if self.add_total_seconds:
-            self.extracted_features_.append("total_seconds")
+            self._partial_features.append("total_seconds")
         if self.add_weekday:
-            self.extracted_features_.append("weekday")
+            self._partial_features.append("weekday")
+        if self.add_day_of_year:
+            self._partial_features.append("day_of_year")
+
+        # Adding transformers for periodic encoding
+        self._required_transformers = {}
+
+        col_name = sbd.name(column)
+        self.extracted_features_ = [
+            f"{col_name}_{_feat}" for _feat in self._partial_features
+        ]
+
+        # Build one periodic encoder per supported datetime attribute, using the
+        # module-level default periods and spline counts
+        if self.periodic_encoding is not None:
+            enc_attr = ["year", "month", "weekday", "hour"]
+            for enc_feature in enc_attr:
+                if self.periodic_encoding == "circular":
+                    self._required_transformers[enc_feature] = _CircularEncoder(
+                        period=_DEFAULT_ENCODING_PERIODS[enc_feature]
+                    )
+                elif self.periodic_encoding == "spline":
+                    self._required_transformers[enc_feature] = _SplineEncoder(
+                        period=_DEFAULT_ENCODING_PERIODS[enc_feature],
+                        n_splines=_DEFAULT_ENCODING_SPLINES[enc_feature],
+                    )
+
+        for _case, t in self._required_transformers.items():
+            _feat = _get_dt_feature(column, _case)
+            _feat_name = sbd.name(_feat) + "_" + _case
+            _feat = sbd.rename(_feat, _feat_name)
+            # Filling null values for periodic encoder
+            t.fit(self._fill_nulls(_feat))
+            self.extracted_features_ += t.all_outputs_
+
         return self.transform(column)
 
     def transform(self, column):
@@ -311,12 +394,41 @@ def transform(self, column):
         """
         check_is_fitted(self, "extracted_features_")
         name = sbd.name(column)
+
+        # Checking again which values are null if calling only transform
+ not_nulls = ~sbd.is_null(column) + # Replacing filled values back with nulls + null_mask = sbd.copy_index(column, sbd.all_null_like(sbd.to_float32(column))) + all_extracted = [] - for feature in self.extracted_features_: + for feature in self._partial_features: extracted = _get_dt_feature(column, feature).rename(f"{name}_{feature}") extracted = sbd.to_float32(extracted) all_extracted.append(extracted) - return sbd.make_dataframe_like(column, all_extracted) + + _new_features = [] + for _case, t in self._required_transformers.items(): + _feat = _get_dt_feature(column, _case) + # filling nulls only to the feature passed to the periodic encoder + _transformed = t.transform(self._fill_nulls(_feat)) + + _new_features.append(_transformed) + + # Setting the index back to that of the input column (pandas shenanigans) + X_out = sbd.copy_index(column, sbd.make_dataframe_like(column, all_extracted)) + X_out = sbd.concat_horizontal(X_out, *_new_features) + + # Censoring all the null features + X_out = sbd.where_row(X_out, not_nulls, null_mask) + + return X_out + + def _fill_nulls(self, column): + # Fill all null values in the column with an arbitrary value + # This value will be replaced by nulls at the end of the transformation + fill_value = 0 + + return sbd.fill_nulls(column, fill_value) def _check_params(self): allowed = _TIME_LEVELS + [None] @@ -325,6 +437,11 @@ def _check_params(self): f"'resolution' options are {allowed}, got {self.resolution!r}." ) + if self.periodic_encoding not in [None, "circular", "spline"]: + raise ValueError( + f"Unsupported value {self.periodic_encoding} for periodic_encoding." + ) + def _more_tags(self): return {"preserves_dtype": []} @@ -332,3 +449,175 @@ def __sklearn_tags__(self): tags = super().__sklearn_tags__() tags.transformer_tags = TransformerTags(preserves_dtype=[]) return tags + + +class _SplineEncoder(SingleColumnTransformer): + """Generate univariate B-spline bases for features. 
+ + This encoder will apply the scikit-learn SplineTransformer to the given + column and produce a dataframe with the encoded features as output. + + Parameters + ---------- + period : int, default=24 + Period of the feature to be used as base for the periodic extrapolation + at the boundaries of the data. + + n_splines : int or None, default=None + Number of splines (features) to be generated. If set to None, ``n_splines`` + is set to be equal to ``period``. + + degree : int, default=3 + Degree of the polynomial used as the spline basis. Must be a non-negative + integer. + """ + + def __init__(self, period=24, n_splines=None, degree=3): + self.period = period + self.n_splines = n_splines + self.degree = degree + + def fit_transform(self, X, y=None): + """Fit the encoder and transform a column. + + Parameters + ---------- + X : pandas or polars Series with dtype Date or Datetime + The input to transform. + + y : None + Ignored. + + Returns + ------- + transformed : DataFrame + The extracted features. + """ + + del y + + self.transformer_ = self._periodic_spline_transformer() + + X_out = self.transformer_.fit_transform(sbd.to_numpy(X).reshape(-1, 1)) + + self.is_fitted = True + self.n_components_ = X_out.shape[1] + + name = sbd.name(X) + self.all_outputs_ = [ + f"{name}_spline_{idx}" for idx in range(self.n_components_) + ] + + return self._post_process(X, X_out) + + def transform(self, X): + """Transform a column. + + Parameters + ---------- + X : pandas or polars Series with dtype Date or Datetime + The input to transform. + + Returns + ------- + transformed : DataFrame + The extracted features. 
+ """ + + X_out = self.transformer_.transform(sbd.to_numpy(X).reshape(-1, 1)) + + return self._post_process(X, X_out) + + def _post_process(self, X, result): + result = sbd.make_dataframe_like(X, dict(zip(self.all_outputs_, result.T))) + result = sbd.copy_index(X, result) + + return result + + def _periodic_spline_transformer(self): + if self.n_splines is None: + self.n_splines = self.period + n_knots = self.n_splines + 1 # periodic and include_bias is True + return SplineTransformer( + degree=self.degree, + n_knots=n_knots, + knots=np.linspace(0, self.period, n_knots).reshape(n_knots, 1), + extrapolation="periodic", + include_bias=True, + ) + + +class _CircularEncoder(SingleColumnTransformer): + """Generate trigonometric features for the given feature. + + This encoder will generate two features corresponding to the sine and cosine + of the feature, based on the given period as output. + + Parameters + ---------- + period : int, default = 24 + Period to be used as basis of the trigonometric function. + """ + + def __init__(self, period=24): + self.period = period + + def fit_transform(self, X, y=None): + """Fit the encoder and transform a column. + + Parameters + ---------- + X : pandas or polars Series with dtype Date or Datetime + The input to transform. + + y : None + Ignored. + + Returns + ------- + transformed : DataFrame + The extracted features. + """ + + del y + + new_features = [ + np.sin(X / self.period * 2 * np.pi), + np.cos(X / self.period * 2 * np.pi), + ] + + self.n_components_ = 2 + + name = sbd.name(X) + self.all_outputs_ = [ + f"{name}_circular_{idx}" for idx in range(self.n_components_) + ] + + return self._post_process(X, new_features) + + def transform(self, X): + """Transform a column. + + Parameters + ---------- + X : pandas or polars Series with dtype Date or Datetime + The input to transform. + + Returns + ------- + transformed : DataFrame + The extracted features. 
+ """ + + new_features = [ + np.sin(X / self.period * 2 * np.pi), + np.cos(X / self.period * 2 * np.pi), + ] + + return self._post_process(X, new_features) + + def _post_process(self, X, result): + result = sbd.make_dataframe_like(X, dict(zip(self.all_outputs_, result))) + result = sbd.copy_index(X, result) + + return result diff --git a/skrub/tests/test_datetime_encoder.py b/skrub/tests/test_datetime_encoder.py index bd1d75c51..b2759cfe4 100644 --- a/skrub/tests/test_datetime_encoder.py +++ b/skrub/tests/test_datetime_encoder.py @@ -5,13 +5,21 @@ from skrub import DatetimeEncoder from skrub import _dataframe as sbd from skrub import _selectors as s +from skrub._datetime_encoder import _CircularEncoder, _SplineEncoder from skrub._on_each_column import OnEachColumn from skrub._to_float32 import ToFloat32 def date(df_module): return sbd.to_datetime( - df_module.make_column("when", ["2020-01-01", None, "2022-01-01"]), + df_module.make_column( + "when", + [ + "2020-01-01", + None, + "2022-01-01", + ], + ), "%Y-%m-%d", ) @@ -20,7 +28,11 @@ def datetime(df_module): return sbd.to_datetime( df_module.make_column( "when", - ["2020-01-01 10:12:01", None, "2022-01-01 23:23:43"], + [ + "2020-01-01 10:12:01", + None, + "2022-01-01 23:23:43", + ], ), "%Y-%m-%d %H:%M:%S", ) @@ -95,6 +107,7 @@ def expected_features(df_module): "when_total_seconds": [1577873536.0, None, 1641079424.0], "when_weekday": [3.0, None, 6.0], } + res = df_module.make_dataframe(values) return OnEachColumn(ToFloat32()).fit_transform(res) @@ -107,7 +120,7 @@ def test_fit_transform(a_datetime_col, expected_features, df_module, use_fit_tra res = enc.fit(a_datetime_col).transform(a_datetime_col) expected_features = s.select( expected_features, - [f"when_{f}" for f in enc.extracted_features_], + [f"{f}" for f in enc.extracted_features_], ) df_module.assert_frame_equal(res, expected_features, rtol=1e-4) @@ -137,19 +150,49 @@ def test_fit_transform(a_datetime_col, expected_features, df_module, use_fit_tra 
dict(add_weekday=True, add_total_seconds=False), ["year", "month", "day", "hour", "weekday"], ), + ( + dict(add_day_of_year=True, add_total_seconds=False), + ["year", "month", "day", "hour", "day_of_year"], + ), + ( + dict( + add_day_of_year=False, + add_total_seconds=False, + periodic_encoding="circular", + ), + [ + "year", + "month", + "day", + "hour", + "year_circular_0", + "year_circular_1", + "month_circular_0", + "month_circular_1", + "weekday_circular_0", + "weekday_circular_1", + "hour_circular_0", + "hour_circular_1", + ], + ), ], ) def test_extracted_features_choice(datetime_cols, params, extracted_features): enc = DatetimeEncoder(**params) res = enc.fit_transform(datetime_cols.datetime) - assert enc.extracted_features_ == extracted_features - assert sbd.column_names(res) == [f"when_{f}" for f in enc.extracted_features_] + assert enc.extracted_features_ == [f"when_{f}" for f in extracted_features] + assert sbd.column_names(res) == [f"{f}" for f in enc.extracted_features_] def test_time_not_extracted_from_date_col(datetime_cols): enc = DatetimeEncoder(resolution="nanosecond") enc.fit(datetime_cols.date) - assert enc.extracted_features_ == ["year", "month", "day", "total_seconds"] + assert enc.extracted_features_ == [ + "when_year", + "when_month", + "when_day", + "when_total_seconds", + ] def test_invalid_resolution(datetime_cols): @@ -160,3 +203,44 @@ def test_invalid_resolution(datetime_cols): def test_reject_non_datetime(df_module): with pytest.raises(ValueError, match=".*does not have Date or Datetime dtype."): DatetimeEncoder().fit_transform(df_module.example_column) + + +# Checking parameters for CircularEncoder and SplineEncoder +@pytest.mark.parametrize( + "params, transformers", + [ + ( + dict( + periodic_encoding="circular", + ), + [_CircularEncoder, _CircularEncoder, _CircularEncoder, _CircularEncoder], + ), + ( + dict( + periodic_encoding="spline", + ), + [_SplineEncoder, _SplineEncoder, _SplineEncoder, _SplineEncoder], + ), + ], +) +def 
test_correct_parameters(a_datetime_col, params, transformers): + enc = DatetimeEncoder(**params) + + enc.fit_transform(a_datetime_col) + + assert all( + [ + isinstance(t, required_t) + for t, required_t in zip(enc._required_transformers.values(), transformers) + ] + ) + + with pytest.raises(ValueError, match="Unsupported value wrongvalue .*"): + DatetimeEncoder(periodic_encoding="wrongvalue").fit_transform(a_datetime_col) + + +def test_error_checking_periodic_encoder(a_datetime_col): + enc = DatetimeEncoder(periodic_encoding="notaparameter") + + with pytest.raises(ValueError, match=r"Unsupported value (\S+) for (\S+)"): + enc.fit_transform(a_datetime_col)