Skip to content

Commit 7be8721

Browse files
committed
Updating the datetime encoder example to add periodic features
1 parent e160dcb commit 7be8721

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

examples/03_datetime_encoder.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
.. |make_column_transformer| replace::
3333
:class:`~sklearn.compose.make_column_transformer`
3434
35-
.. |HGBR| replace::
36-
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`
35+
.. |RidgeCV| replace::
36+
:class:`~sklearn.linear_model.RidgeCV`
3737
3838
.. |ToDatetime| replace::
3939
:class:`~skrub.ToDatetime`
@@ -128,20 +128,39 @@
128128
# |DatetimeEncoder| is used on the correct column(s).
129129
pprint(table_vec_weekday.transformers_)
130130

131+
###############################################################################
132+
# The |DatetimeEncoder| can generate additional periodic features using either
133+
# B-Splines (|SplineTransformer|) or trigonometric functions. To do so, set the
134+
# ``periodic_encoding`` parameter to ``circular`` or ``spline``. In this
135+
# example, we use ``spline``.
136+
137+
table_vec_periodic = TableVectorizer(
138+
datetime=DatetimeEncoder(
139+
periodic_encoding="spline",
140+
)
141+
).fit(X)
142+
131143
###############################################################################
132144
# Prediction with datetime features
133145
# ---------------------------------
134146
#
135147
# For prediction tasks, we recommend using the |TableVectorizer| inside a
136148
# pipeline, combined with a model that can use the features extracted by the
137149
# |DatetimeEncoder|.
138-
# Here we'll use a |HGBR| as our learner.
150+
# Here we'll use a |RidgeCV| model as our learner.
139151
#
140-
from sklearn.ensemble import HistGradientBoostingRegressor
152+
from sklearn.impute import SimpleImputer
153+
from sklearn.linear_model import RidgeCV
141154
from sklearn.pipeline import make_pipeline
155+
from sklearn.preprocessing import StandardScaler
142156

143-
pipeline = make_pipeline(table_vec, HistGradientBoostingRegressor())
144-
pipeline_weekday = make_pipeline(table_vec_weekday, HistGradientBoostingRegressor())
157+
pipeline = make_pipeline(table_vec, StandardScaler(), SimpleImputer(), RidgeCV())
158+
pipeline_weekday = make_pipeline(
159+
table_vec_weekday, StandardScaler(), SimpleImputer(), RidgeCV()
160+
)
161+
pipeline_periodic = make_pipeline(
162+
table_vec_periodic, StandardScaler(), SimpleImputer(), RidgeCV()
163+
)
145164

146165
###############################################################################
147166
# Evaluating the model

0 commit comments

Comments
 (0)