my_stats_spark.py
#
# USAGE: $SPARK_HOME/bin/spark-submit my_stats_spark.py
#
from pyspark.mllib.stat import Statistics
from pyspark import SparkContext
import scipy.stats
import pytest
sc = SparkContext('local', 'stats')
# CDF must be declared outside of the class, otherwise Spark tries to ship
# the whole object along with the closure
def CDF(a, b, l, x):
    # empirical-CDF distance at x: |#{a <= x} - #{b <= x}| / sample length
    a_count = len([y for y in a.value if y <= x])
    b_count = len([y for y in b.value if y <= x])
    return abs(a_count - b_count) / l
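
# Worked example (illustrative values, a sketch): for samples a = [1, 2, 3]
# and b = [2, 3, 4] with l = 3, CDF at x = 1 is |1 - 0| / 3; taking the max
# over all distinct values of both samples, as ks() does below, gives the
# KS statistic 1/3.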
class BasicStats(object):
    def __init__(self, a, b):
        # keep driver-side copies for broadcasting in ks()
        self.a_driver = a
        self.b_driver = b
        # RDD versions, coerced to float so mllib and division behave
        self.a = sc.parallelize(a).map(lambda x: x + 0.0)
        self.b = sc.parallelize(b).map(lambda x: x + 0.0)
        # both samples are assumed to have the same length
        self.length = len(a) + 0.0
    def ks(self):
        # two-sample Kolmogorov-Smirnov statistic: the largest distance
        # between the two empirical CDFs over all distinct sample values
        a_br = sc.broadcast(self.a_driver)
        b_br = sc.broadcast(self.b_driver)
        # bind to a local so the lambda doesn't capture self
        length = self.length
        return self.a.union(self.b).distinct().map(lambda x: CDF(a_br, b_br, length, x)).max()

    def pearson(self):
        return Statistics.corr(self.a, self.b, 'pearson')
class AdvancedStats(BasicStats):
    def __init__(self, a, b):
        BasicStats.__init__(self, a, b)

    def rho(self):
        return Statistics.corr(self.a, self.b, 'spearman')

    def tau(self):
        # Kendall's tau-a: score every pair of observations as concordant (+1)
        # or discordant (-1), then normalize by the total number of pairs
        indexed = self.a.zip(self.b).zipWithIndex()  # don't shadow builtin zip
        denominator = self.length * (self.length - 1) / 2.0
        # the index keeps each unordered pair exactly once
        pairs = indexed.cartesian(indexed).filter(lambda x: x[0][1] > x[1][1])
        differences = pairs.map(lambda x: (x[0][0][0] - x[1][0][0]) * (x[0][0][1] - x[1][0][1]))
        numerator = differences.filter(lambda x: x != 0).map(lambda x: (1 if x > 0 else -1)).sum()
        return numerator / denominator
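
# Worked example for tau (illustrative values, a sketch): for observations
# (1, 2) and (3, 1), the product (1 - 3) * (2 - 1) = -2 is negative, so the
# pair is discordant and contributes -1 to the numerator.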
def test_ks():
    list1 = [1, 2, 3, 4, 5, 6]
    list2 = [9, 8, 7, 6, 5, 4]
    # make sure my implementation matches the non-distributed scipy version
    stats = BasicStats(list1, list2)
    assert stats.ks() == scipy.stats.ks_2samp(list1, list2).statistic
def test_tau():
    list1 = [0, 2, 4, 4, 3, 10]
    list2 = [4, 5, 0, 3, 2, 1]
    # make sure my implementation matches the non-distributed scipy version;
    # scipy's kendalltau computes tau-b, which corrects for ties, while tau()
    # above is tau-a, so the two drift apart on longer sequences with ties
    stats = AdvancedStats(list1, list2)
    assert abs(stats.tau() - scipy.stats.kendalltau(list1, list2)[0]) < 0.02
if __name__ == '__main__':
    test_ks()
    test_tau()
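    # quick demo on a hand-picked anti-correlated pair (illustrative values,
    # not part of the tests above)
    demo = AdvancedStats([1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0])
    print('pearson: %s' % demo.pearson())  # expect -1.0
    print('spearman: %s' % demo.rho())     # expect -1.0
    print('tau: %s' % demo.tau())          # expect -1.0, every pair is discordant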
    # normally this would be started with "pytest.main([__file__])",
    # but I'm not sure how pyspark and pytest would behave together