Skip to content

Commit 97011bd

Browse files
ENH: optimize StringEncoder (#1248)
Co-authored-by: Jerome Dockes <[email protected]>
1 parent 3b19394 commit 97011bd

File tree

3 files changed

+38
-15
lines changed

3 files changed

+38
-15
lines changed

CHANGES.rst

+3
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,9 @@ Changes
2424
- Progress messages when generating a ``TableReport`` are now written to stderr instead of stdout.
2525
:pr:`1236` by :user:`Priscilla Baah<priscilla-b>`
2626

27+
- Optimize the :class:`StringEncoder`: significant memory reduction and 1.5x speed-up.
28+
:pr:`1248` by :user:`Gaël Varoquaux <gaelvaroquaux>`
29+
2730
Release 0.5.1
2831
=============
2932

skrub/_string_encoder.py

+17-7
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ class StringEncoder(SingleColumnTransformer):
2727
n_components : int, default=30
2828
Number of components to be used for the singular value decomposition (SVD).
2929
Must be a positive integer.
30-
vectorizer : str, "tfidf" or "hashing"
30+
vectorizer : str, "tfidf" or "hashing", default="tfidf"
3131
Vectorizer to apply to the strings, either `tfidf` or `hashing` for
3232
scikit-learn TfidfVectorizer or HashingVectorizer respectively.
3333
@@ -133,12 +133,17 @@ def fit_transform(self, X, y=None):
133133
f" 'hashing', got {self.vectorizer!r}"
134134
)
135135

136-
X = sbd.fill_nulls(X, "")
137-
X_out = self.vectorizer_.fit_transform(X)
136+
X_filled = sbd.fill_nulls(X, "")
137+
X_out = self.vectorizer_.fit_transform(X_filled).astype("float32")
138+
del X_filled # optimizes memory: we no longer need X_filled
138139

139-
if (min_shape := min(X_out.shape)) >= self.n_components:
140-
self.tsvd_ = TruncatedSVD(n_components=self.n_components)
140+
if (min_shape := min(X_out.shape)) > self.n_components:
141+
self.tsvd_ = TruncatedSVD(
142+
n_components=self.n_components, algorithm="arpack"
143+
)
141144
result = self.tsvd_.fit_transform(X_out)
145+
elif X_out.shape[1] == self.n_components:
146+
result = X_out.toarray()
142147
else:
143148
warnings.warn(
144149
f"The matrix shape is {(X_out.shape)}, and its minimum is "
@@ -152,6 +157,8 @@ def fit_transform(self, X, y=None):
152157
# Therefore, self.n_components_ below stores the resulting
153158
# number of dimensions of result.
154159
result = X_out[:, : self.n_components].toarray()
160+
result = result.copy() # To avoid a reference to X_out
161+
del X_out # optimize memory: we no longer need X_out
155162

156163
self._is_fitted = True
157164
self.n_components_ = result.shape[1]
@@ -177,12 +184,15 @@ def transform(self, X):
177184
The embedding representation of the input.
178185
"""
179186

180-
X = sbd.fill_nulls(X, "")
181-
X_out = self.vectorizer_.transform(X)
187+
X_filled = sbd.fill_nulls(X, "")
188+
X_out = self.vectorizer_.transform(X_filled).astype("float32")
189+
del X_filled # optimizes memory: we no longer need X_filled
182190
if hasattr(self, "tsvd_"):
183191
result = self.tsvd_.transform(X_out)
184192
else:
185193
result = X_out[:, : self.n_components].toarray()
194+
result = result.copy()
195+
del X_out # optimize memory: we no longer need X_out
186196

187197
return self._post_process(X, result)
188198

skrub/tests/test_string_encoder.py

+18-8
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from numpy.testing import assert_almost_equal
23
from sklearn.base import clone
34
from sklearn.decomposition import TruncatedSVD
45
from sklearn.feature_extraction.text import (
@@ -37,6 +38,7 @@ def test_tfidf_vectorizer(encode_column, df_module):
3738
]
3839
)
3940
check = pipe.fit_transform(sbd.to_numpy(encode_column))
41+
check = check.astype("float32") # StringEncoder is float32
4042

4143
names = [f"col1_{idx}" for idx in range(2)]
4244

@@ -191,27 +193,35 @@ def test_n_components(df_module):
191193
assert encoder_30.n_components_ == 30
192194

193195

196+
def test_n_components_equal_voc_size(df_module):
197+
x = df_module.make_column("x", ["aab", "bba"])
198+
encoder = StringEncoder(n_components=2, ngram_range=(1, 1), analyzer="char")
199+
out = encoder.fit_transform(x)
200+
assert sbd.column_names(out) == ["x_0", "x_1"]
201+
assert not hasattr(encoder, "tsvd_")
202+
203+
194204
@pytest.mark.parametrize("vectorizer", ["tfidf", "hashing"])
195205
def test_missing_values(df_module, vectorizer):
196206
col = df_module.make_column("col", ["one two", None, "", "two three"])
197207
encoder = StringEncoder(n_components=2, vectorizer=vectorizer)
198208
out = encoder.fit_transform(col)
199209
for c in sbd.to_column_list(out):
200-
assert c[1] == 0.0
201-
assert c[2] == 0.0
210+
assert_almost_equal(c[1], 0.0, decimal=6)
211+
assert_almost_equal(c[2], 0.0, decimal=6)
202212
out = encoder.transform(col)
203213
for c in sbd.to_column_list(out):
204-
assert c[1] == 0.0
205-
assert c[2] == 0.0
214+
assert_almost_equal(c[1], 0.0, decimal=6)
215+
assert_almost_equal(c[2], 0.0, decimal=6)
206216
tv = TableVectorizer(
207217
low_cardinality=StringEncoder(n_components=2, vectorizer=vectorizer)
208218
)
209219
df = df_module.make_dataframe({"col": col})
210220
out = tv.fit_transform(df)
211221
for c in sbd.to_column_list(out):
212-
assert c[1] == 0.0
213-
assert c[2] == 0.0
222+
assert_almost_equal(c[1], 0.0, decimal=6)
223+
assert_almost_equal(c[2], 0.0, decimal=6)
214224
out = tv.transform(df)
215225
for c in sbd.to_column_list(out):
216-
assert c[1] == 0.0
217-
assert c[2] == 0.0
226+
assert_almost_equal(c[1], 0.0, decimal=6)
227+
assert_almost_equal(c[2], 0.0, decimal=6)

0 commit comments

Comments
 (0)