1
1
"""Independence of 2 variables."""
2
2
3
- from typing import Type
3
+ from collections . abc import Callable
4
4
5
5
import numpy as np
6
6
import torch
9
9
from .density_estimation import Kde
10
10
11
11
12
- def _joint_2 (x : Tensor , y : Tensor , density : Type [Kde ], damping : float = 1e-10 ) -> Tensor :
12
+ def _joint_2 (
13
+ x : Tensor , y : Tensor , density : Callable [[Tensor ], Kde ], damping : float = 1e-10
14
+ ) -> Tensor :
13
15
x = (x - x .mean ()) / x .std ()
14
16
y = (y - y .mean ()) / y .std ()
15
17
data = torch .cat ([x .unsqueeze (- 1 ), y .unsqueeze (- 1 )], - 1 )
@@ -27,7 +29,7 @@ def _joint_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -
27
29
return h2d
28
30
29
31
30
- def hgr (x : Tensor , y : Tensor , density : Type [ Kde ], damping : float = 1e-10 ) -> Tensor :
32
+ def hgr (x : Tensor , y : Tensor , density : Callable [[ Tensor ], Kde ], damping : float = 1e-10 ) -> Tensor :
31
33
"""An estimator of the Hirschfeld-Gebelein-Renyi maximum correlation coefficient.
32
34
33
35
This function is using Witsenhausen’s Characterization.
@@ -48,7 +50,7 @@ def hgr(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -> Ten
48
50
return torch .svd (Q )[1 ][1 ]
49
51
50
52
51
- def chi_2 (x : Tensor , y : Tensor , density : Type [ Kde ], damping : float = 0 ) -> Tensor :
53
+ def chi_2 (x : Tensor , y : Tensor , density : Callable [[ Tensor ], Kde ], damping : float = 0 ) -> Tensor :
52
54
r"""The :math:`\chi^2` divergence between the joint distribution and the product of marginals.
53
55
54
56
This is know to be the square of an upper-bound on the Hirschfeld-Gebelein-Renyi maximum
@@ -71,7 +73,9 @@ def chi_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 0) -> Tenso
71
73
# Independence of conditional variables
72
74
73
75
74
- def _joint_3 (x : Tensor , y : Tensor , z : Tensor , density : Type [Kde ], damping : float = 1e-10 ) -> Tensor :
76
+ def _joint_3 (
77
+ x : Tensor , y : Tensor , z : Tensor , density : Callable [[Tensor ], Kde ], damping : float = 1e-10
78
+ ) -> Tensor :
75
79
x = (x - x .mean ()) / x .std ()
76
80
y = (y - y .mean ()) / y .std ()
77
81
z = (z - z .mean ()) / z .std ()
@@ -90,7 +94,7 @@ def _joint_3(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde], damping: float
90
94
return h3d
91
95
92
96
93
- def hgr_cond (x : Tensor , y : Tensor , z : Tensor , density : Type [ Kde ]) -> np .ndarray :
97
+ def hgr_cond (x : Tensor , y : Tensor , z : Tensor , density : Callable [[ Tensor ], Kde ]) -> np .ndarray :
94
98
r"""An estimator of the function :math:`z\to HGR(x|z, y|z)`.
95
99
96
100
Where HGR is the Hirschfeld-Gebelein-Renyi maximum correlation
@@ -113,7 +117,7 @@ def hgr_cond(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde]) -> np.ndarray:
113
117
return np .array ([torch .svd (Q [:, :, i ])[1 ][1 ] for i in range (Q .shape [2 ])])
114
118
115
119
116
- def chi_2_cond (x : Tensor , y : Tensor , z : Tensor , density : Type [ Kde ]) -> Tensor :
120
+ def chi_2_cond (x : Tensor , y : Tensor , z : Tensor , density : Callable [[ Tensor ], Kde ]) -> Tensor :
117
121
r"""An estimator of the function :math:`z\to chi^2(x|z, y|z)`.
118
122
119
123
Where :math:`\chi^2` is the :math:`\chi^2` divergence between the joint distribution on (x,y)
0 commit comments