Skip to content

Commit 08b79df

Browse files
committed
add drop_axis option to columnselector
1 parent 472b1e3 commit 08b79df

File tree

4 files changed

+28
-23
lines changed

4 files changed

+28
-23
lines changed

docs/sources/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ The CHANGELOG for the current development version is available at
2525
- New function implementing the 5x2cv paired t-test procedure (`paired_ttest_5x2cv`) proposed by Dieterrich (1998)
2626
to compare the performance of two models. ([#325](https://github.com/rasbt/mlxtend/issues/325))
2727
- A `refit` parameter was added to stacking classes (similar to the `refit` parameter in the `EnsembleVoteClassifier`), to support classifiers and regressors that follow the scikit-learn API but are not compatible with scikit-learn's `clone` function ([#325](https://github.com/rasbt/mlxtend/issues/324))
28+
- The `ColumnSelector` now has a `drop_axis` argument to use it in pipelines with `CountVectorizers` ([#333](https://github.com/rasbt/mlxtend/pull/333)
2829

2930
##### Changes
3031

docs/sources/user_guide/feature_selection/ColumnSelector.ipynb

+5-20
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
{
44
"cell_type": "code",
55
"execution_count": 1,
6-
"metadata": {
7-
"collapsed": true
8-
},
6+
"metadata": {},
97
"outputs": [],
108
"source": [
119
"%matplotlib inline"
@@ -72,9 +70,7 @@
7270
{
7371
"cell_type": "code",
7472
"execution_count": 2,
75-
"metadata": {
76-
"collapsed": true
77-
},
73+
"metadata": {},
7874
"outputs": [],
7975
"source": [
8076
"from sklearn.datasets import load_iris\n",
@@ -183,9 +179,7 @@
183179
{
184180
"cell_type": "code",
185181
"execution_count": 5,
186-
"metadata": {
187-
"collapsed": true
188-
},
182+
"metadata": {},
189183
"outputs": [],
190184
"source": [
191185
"from sklearn.datasets import load_iris\n",
@@ -283,7 +277,7 @@
283277
"text": [
284278
"## ColumnSelector\n",
285279
"\n",
286-
"*ColumnSelector(cols=None)*\n",
280+
"*ColumnSelector(cols=None, drop_axis=False)*\n",
287281
"\n",
288282
"Base class for all estimators in scikit-learn\n",
289283
"\n",
@@ -402,15 +396,6 @@
402396
" s = f.read() + '<br><br>'\n",
403397
"print(s)"
404398
]
405-
},
406-
{
407-
"cell_type": "code",
408-
"execution_count": null,
409-
"metadata": {
410-
"collapsed": true
411-
},
412-
"outputs": [],
413-
"source": []
414399
}
415400
],
416401
"metadata": {
@@ -430,7 +415,7 @@
430415
"name": "python",
431416
"nbconvert_exporter": "python",
432417
"pygments_lexer": "ipython3",
433-
"version": "3.6.1"
418+
"version": "3.6.4"
434419
}
435420
},
436421
"nbformat": 4,

mlxtend/feature_selection/column_selector.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
class ColumnSelector(BaseEstimator):
1515

16-
def __init__(self, cols=None):
16+
def __init__(self, cols=None, drop_axis=False):
1717
"""Object for selecting specific columns from a data set.
1818
1919
Parameters
@@ -22,8 +22,18 @@ def __init__(self, cols=None):
2222
A list specifying the feature indices to be selected. For example,
2323
[1, 4, 5] to select the 2nd, 5th, and 6th feature columns.
2424
If None, returns all columns in the array.
25+
26+
drop_axis : bool (default=False)
27+
Drops last axis if True and the only one column is selected. This
28+
is useful, e.g., when the ColumnSelector is used for selecting
29+
only one column and the resulting array should be fed to e.g.,
30+
a scikit-learn column selector. E.g., instead of returning an
31+
array with shape (n_samples, 1), drop_axis=True will return an
32+
aray with shape (n_samples,).
33+
2534
"""
2635
self.cols = cols
36+
self.drop_axis = drop_axis
2737

2838
def fit_transform(self, X, y=None):
2939
""" Return a slice of the input array.
@@ -60,7 +70,7 @@ def transform(self, X, y=None):
6070
6171
"""
6272
t = X[:, self.cols]
63-
if len(t.shape) == 1:
73+
if len(t.shape) == 1 and not self.drop_axis:
6474
t = t[:, np.newaxis]
6575
return t
6676

mlxtend/feature_selection/tests/test_column_selector.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,16 @@
1717
def test_ColumnSelector():
1818
X1_in = np.ones((4, 8))
1919
X1_out = ColumnSelector(cols=(1, 3)).transform(X1_in)
20-
assert(X1_out.shape == (4, 2))
20+
assert X1_out.shape == (4, 2)
21+
22+
23+
def test_ColumnSelector_drop_axis():
24+
X1_in = np.ones((4, 8))
25+
X1_out = ColumnSelector(cols=(1), drop_axis=True).transform(X1_in)
26+
assert X1_out.shape == (4,)
27+
28+
X1_out = ColumnSelector(cols=(1)).transform(X1_in)
29+
assert X1_out.shape == (4, 1)
2130

2231

2332
def test_ColumnSelector_in_gridsearch():

0 commit comments

Comments
 (0)