Skip to content

Commit 8eb6954

Browse files
author
Sebastian Raschka
authored
Merge pull request #721 from DarthTrevis/permutations
Permutations
2 parents ac50506 + 4bcc3e9 commit 8eb6954

File tree

5 files changed

+28
-22
lines changed

5 files changed

+28
-22
lines changed

docs/sources/CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ The CHANGELOG for the current development version is available at
2323

2424
##### Changes
2525

26-
- -
26+
- `permutation_test` (`mlxtend.evaluate.permutation`) ìs corrected to give the proportion of permutations whose statistic is *at least as extreme* as the one observed. ([#721](https://github.com/rasbt/mlxtend/pull/721) via [Florian Charlier](https://github.com/DarthTrevis))
2727

2828
##### Bug Fixes
2929

docs/sources/user_guide/evaluate/permutation_test.ipynb

+6-6
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,13 @@
4343
"4. Divide the permuted dataset into two datasets x' and y' of size *n* and *m*, respectively\n",
4444
"5. Compute the difference (here: mean) of sample x' and sample y' and record this difference\n",
4545
"6. Repeat steps 3-5 until all permutations are evaluated\n",
46-
"7. Return the p-value as the number of times the recorded differences were more extreme than the original difference from 1. and divide this number by the total number of permutations\n",
46+
"7. Return the p-value as the number of times the recorded differences were at least as extreme as the original difference from 1. and divide this number by the total number of permutations\n",
4747
"\n",
4848
"Here, the p-value is defined as the probability, given the null hypothesis (no difference between the samples) is true, that we obtain results that are at least as extreme as the results we observed (i.e., the sample difference from 1.).\n",
4949
"\n",
50-
"More formally, we can express the computation of the p-value as follows ([2]):\n",
50+
"More formally, we can express the computation of the p-value as follows (adapted from [2]):\n",
5151
"\n",
52-
"$$p(t > t_0) = \\frac{1}{(n+m)!} \\sum^{(n+m)!}_{j=1} I(t_j > t_0),$$\n",
52+
"$$p(t \\geq t_0) = \\frac{1}{(n+m)!} \\sum^{(n+m)!}_{j=1} I(t_j \\geq t_0),$$\n",
5353
"\n",
5454
"where $t_0$ is the observed value of the test statistic (1. in the list above), and $t$ is the t-value, the statistic computed from the resamples (5.) $t(x'_1, x'_2, ..., x'_n, y'_1, y'_2, ..., x'_m) = |\\bar{x'} - \\bar{y'}|$, and *I* is the indicator function.\n",
5555
"\n",
@@ -114,7 +114,7 @@
114114
"name": "stdout",
115115
"output_type": "stream",
116116
"text": [
117-
"0.0066\n"
117+
"0.0066993300669933005\n"
118118
]
119119
}
120120
],
@@ -159,7 +159,7 @@
159159
"output_type": "stream",
160160
"text": [
161161
"Observed pearson R: 0.81\n",
162-
"P value: 0.09\n"
162+
"P value: 0.10\n"
163163
]
164164
}
165165
],
@@ -281,7 +281,7 @@
281281
"name": "python",
282282
"nbconvert_exporter": "python",
283283
"pygments_lexer": "ipython3",
284-
"version": "3.6.4"
284+
"version": "3.8.5"
285285
},
286286
"toc": {
287287
"nav_menu": {},

mlxtend/evaluate/permutation.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def func(x, y):
9797
m, n = len(x), len(y)
9898
combined = np.hstack((x, y))
9999

100-
more_extreme = 0.
100+
at_least_as_extreme = 0.
101101
reference_stat = func(x, y)
102102

103103
# Note that whether we compute the combinations or permutations
@@ -120,15 +120,21 @@ def func(x, y):
120120
indices_y = [i for i in range(m + n) if i not in indices_x]
121121
diff = func(combined[list(indices_x)], combined[indices_y])
122122

123-
if diff > reference_stat:
124-
more_extreme += 1.
123+
if diff > reference_stat or np.isclose(diff, reference_stat):
124+
at_least_as_extreme += 1.
125125

126126
num_rounds = factorial(m + n) / (factorial(m)*factorial(n))
127127

128128
else:
129129
for i in range(num_rounds):
130130
rng.shuffle(combined)
131-
if func(combined[:m], combined[m:]) > reference_stat:
132-
more_extreme += 1.
131+
diff = func(combined[:m], combined[m:])
133132

134-
return more_extreme / num_rounds
133+
if diff > reference_stat or np.isclose(diff, reference_stat):
134+
at_least_as_extreme += 1.
135+
136+
# To cover the actual experiment results
137+
at_least_as_extreme += 1
138+
num_rounds += 1
139+
140+
return at_least_as_extreme / num_rounds

mlxtend/evaluate/tests/test_permutation.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -16,36 +16,36 @@
1616
def test_one_sided_x_greater_y():
1717
p = permutation_test(treatment, control,
1818
func=lambda x, y: np.mean(x) - np.mean(y))
19-
assert round(p, 4) == 0.0274, p
19+
assert round(p, 4) == 0.0301, p
2020

2121
p = permutation_test(treatment, control,
2222
func="x_mean > y_mean")
23-
assert round(p, 4) == 0.0274, p
23+
assert round(p, 4) == 0.0301, p
2424

2525

2626
def test_one_sided_y_greater_x():
2727
p = permutation_test(treatment, control,
2828
func=lambda x, y: np.mean(y) - np.mean(x))
29-
assert round(p, 3) == 1 - 0.03, p
29+
assert round(p, 3) == 0.973, p
3030

3131
p = permutation_test(treatment, control,
3232
func="x_mean < y_mean")
33-
assert round(p, 3) == 1 - 0.03, p
33+
assert round(p, 3) == 0.973, p
3434

3535

3636
def test_two_sided():
3737
p = permutation_test(treatment, control,
3838
func=lambda x, y: np.abs(np.mean(x) - np.mean(y)))
39-
assert round(p, 3) == 0.055, p
39+
assert round(p, 3) == 0.060, p
4040

4141
p = permutation_test(treatment, control,
4242
func="x_mean != y_mean")
43-
assert round(p, 3) == 0.055, p
43+
assert round(p, 3) == 0.060, p
4444

4545

4646
def test_default():
4747
p = permutation_test(treatment, control)
48-
assert round(p, 3) == 0.055, p
48+
assert round(p, 3) == 0.060, p
4949

5050

5151
def test_approximateone_sided_x_greater_y():
@@ -55,7 +55,7 @@ def test_approximateone_sided_x_greater_y():
5555
method='approximate',
5656
num_rounds=5000,
5757
seed=123)
58-
assert round(p, 3) == 0.028, p
58+
assert round(p, 3) == 0.031, p
5959

6060

6161
def test_invalid_method():

mlxtend/feature_extraction/tests/test_principal_component_analysis.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_fail_array_dimension_2():
108108
def test_variance_explained_ratio():
109109
pca = PCA()
110110
pca.fit(X_std)
111-
assert np.sum(pca.e_vals_normalized_) == 1.
111+
assert_almost_equal(np.sum(pca.e_vals_normalized_), 1.)
112112
assert np.sum(pca.e_vals_normalized_ < 0.) == 0
113113

114114

0 commit comments

Comments
 (0)