
Commit bf8c4fd

Don't use type[T]; use Callable[..., T] (#1073)
It's much more type-safe.
2 parents ce8cedd + bbbf23f
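To illustrate the type-safety claim, here is a minimal sketch (not part of this commit; Base, Sub, and BaseFactory are made-up names). A type[T] annotation lets a checker approve constructor calls that subclasses can break, while Callable[..., T] or a Protocol states the real requirement: "something I can call to get a T".

from typing import Protocol


class Base:
    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale


class Sub(Base):
    # A subclass may change its __init__ freely; type[Base] says nothing about that.
    def __init__(self, scale: float, shift: float) -> None:
        super().__init__(scale)
        self.shift = shift


def build_from_type(cls: type[Base]) -> Base:
    return cls()  # checkers accept this for any subclass, but Sub() crashes


class BaseFactory(Protocol):
    # The actual requirement: a zero-argument callable that yields a Base.
    def __call__(self) -> Base: ...


def build_from_factory(factory: BaseFactory) -> Base:
    return factory()  # checkers verify this call against the protocol


try:
    build_from_type(Sub)  # passes the type checker, fails at runtime
except TypeError as err:
    print(f"type[Base] over-promised: {err}")

build_from_factory(Base)                   # the class itself qualifies
build_from_factory(lambda: Sub(1.0, 0.0))  # so does any adapter or partial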

File tree: 7 files changed, +46 -28 lines

ethicml/implementations/dro_modules/dro_loss.py (+6 -2)

@@ -1,6 +1,6 @@
 """DRO Loss."""
 
-from typing import Type
+from typing import Protocol
 from typing_extensions import override
 
 from torch import Tensor, nn
@@ -9,10 +9,14 @@
 __all__ = ["DROLoss"]
 
 
+class LossFactory(Protocol):
+    def __call__(self, *, reduction: str = "mean") -> _Loss: ...
+
+
 class DROLoss(nn.Module):
     """Fairness Without Demographics Loss."""
 
-    def __init__(self, loss_module: Type[_Loss] | None = None, eta: float = 0.5):
+    def __init__(self, loss_module: LossFactory | None = None, eta: float = 0.5):
         super().__init__()
         if loss_module is None:
             loss_module = NLLLoss

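What the new LossFactory protocol buys, sketched below. This is illustrative, not code from the commit: build_loss is hypothetical, and the import path is inferred from the file location. The keyword-only reduction call is checked against whatever is passed in, and plain loss classes as well as functools.partial objects satisfy it.

from functools import partial

from torch import nn
from torch.nn.modules.loss import _Loss

from ethicml.implementations.dro_modules.dro_loss import LossFactory


def build_loss(factory: LossFactory) -> _Loss:
    return factory(reduction="none")  # the protocol guarantees this keyword exists


build_loss(nn.NLLLoss)                                     # a loss class qualifies
build_loss(partial(nn.CrossEntropyLoss, ignore_index=-1))  # so does a configured partial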
ethicml/implementations/hgr_modules/facl/facl_hgr.py (+11 -7)

@@ -1,6 +1,6 @@
 """Independence of 2 variables."""
 
-from typing import Type
+from collections.abc import Callable
 
 import numpy as np
 import torch
@@ -9,7 +9,9 @@
 from .density_estimation import Kde
 
 
-def _joint_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -> Tensor:
+def _joint_2(
+    x: Tensor, y: Tensor, density: Callable[[Tensor], Kde], damping: float = 1e-10
+) -> Tensor:
     x = (x - x.mean()) / x.std()
     y = (y - y.mean()) / y.std()
     data = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], -1)
@@ -27,7 +29,7 @@ def _joint_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -
     return h2d
 
 
-def hgr(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -> Tensor:
+def hgr(x: Tensor, y: Tensor, density: Callable[[Tensor], Kde], damping: float = 1e-10) -> Tensor:
     """An estimator of the Hirschfeld-Gebelein-Renyi maximum correlation coefficient.
 
     This function is using Witsenhausen’s Characterization.
@@ -48,7 +50,7 @@ def hgr(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -> Ten
     return torch.svd(Q)[1][1]
 
 
-def chi_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 0) -> Tensor:
+def chi_2(x: Tensor, y: Tensor, density: Callable[[Tensor], Kde], damping: float = 0) -> Tensor:
     r"""The :math:`\chi^2` divergence between the joint distribution and the product of marginals.
 
     This is know to be the square of an upper-bound on the Hirschfeld-Gebelein-Renyi maximum
@@ -71,7 +73,9 @@ def chi_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 0) -> Tenso
 # Independence of conditional variables
 
 
-def _joint_3(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde], damping: float = 1e-10) -> Tensor:
+def _joint_3(
+    x: Tensor, y: Tensor, z: Tensor, density: Callable[[Tensor], Kde], damping: float = 1e-10
+) -> Tensor:
     x = (x - x.mean()) / x.std()
     y = (y - y.mean()) / y.std()
     z = (z - z.mean()) / z.std()
@@ -90,7 +94,7 @@ def _joint_3(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde], damping: float
     return h3d
 
 
-def hgr_cond(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde]) -> np.ndarray:
+def hgr_cond(x: Tensor, y: Tensor, z: Tensor, density: Callable[[Tensor], Kde]) -> np.ndarray:
    r"""An estimator of the function :math:`z\to HGR(x|z, y|z)`.
 
     Where HGR is the Hirschfeld-Gebelein-Renyi maximum correlation
@@ -113,7 +117,7 @@ def hgr_cond(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde]) -> np.ndarray:
     return np.array([torch.svd(Q[:, :, i])[1][1] for i in range(Q.shape[2])])
 
 
-def chi_2_cond(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde]) -> Tensor:
+def chi_2_cond(x: Tensor, y: Tensor, z: Tensor, density: Callable[[Tensor], Kde]) -> Tensor:
     r"""An estimator of the function :math:`z\to chi^2(x|z, y|z)`.
 
     Where :math:`\chi^2` is the :math:`\chi^2` divergence between the joint distribution on (x,y)

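A hedged sketch of a call site under the new signature, assuming (as the annotation suggests) that Kde is constructed from the data tensor; the clamped_kde adapter is hypothetical and the import paths are inferred from the file locations. Anything that maps a Tensor to a Kde now fits, not only the Kde class itself.

import torch

from ethicml.implementations.hgr_modules.density_estimation import Kde
from ethicml.implementations.hgr_modules.facl_hgr import hgr

x = torch.randn(256)
y = 0.8 * x + 0.2 * torch.randn(256)

# The Kde class itself is a valid Callable[[Tensor], Kde]:
score = hgr(x, y, density=Kde)


# A hypothetical adapter that preprocesses the data before estimation
# also satisfies the annotation, which Type[Kde] could not express:
def clamped_kde(data: torch.Tensor) -> Kde:
    return Kde(data.clamp(-3.0, 3.0))


score_clamped = hgr(x, y, density=clamped_kde)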
ethicml/implementations/hgr_modules/facl_hgr.py (+11 -7)

@@ -1,6 +1,6 @@
 """Independence of 2 variables."""
 
-from typing import Type
+from collections.abc import Callable
 
 import numpy as np
 import torch
@@ -9,7 +9,9 @@
 from .density_estimation import Kde
 
 
-def _joint_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -> Tensor:
+def _joint_2(
+    x: Tensor, y: Tensor, density: Callable[[Tensor], Kde], damping: float = 1e-10
+) -> Tensor:
     x = (x - x.mean()) / x.std()
     y = (y - y.mean()) / y.std()
     data = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], -1)
@@ -27,7 +29,7 @@ def _joint_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -
     return h2d
 
 
-def hgr(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -> Tensor:
+def hgr(x: Tensor, y: Tensor, density: Callable[[Tensor], Kde], damping: float = 1e-10) -> Tensor:
     """An estimator of the Hirschfeld-Gebelein-Renyi maximum correlation coefficient.
 
     This function is using Witsenhausen’s Characterization.
@@ -48,7 +50,7 @@ def hgr(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 1e-10) -> Ten
     return torch.svd(Q)[1][1]
 
 
-def chi_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 0) -> Tensor:
+def chi_2(x: Tensor, y: Tensor, density: Callable[[Tensor], Kde], damping: float = 0) -> Tensor:
     r"""The :math:`\chi^2` divergence between the joint distribution and the product of marginals.
 
     This is know to be the square of an upper-bound on the Hirschfeld-Gebelein-Renyi maximum
@@ -71,7 +73,9 @@ def chi_2(x: Tensor, y: Tensor, density: Type[Kde], damping: float = 0) -> Tenso
 # Independence of conditional variables
 
 
-def _joint_3(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde], damping: float = 1e-10) -> Tensor:
+def _joint_3(
+    x: Tensor, y: Tensor, z: Tensor, density: Callable[[Tensor], Kde], damping: float = 1e-10
+) -> Tensor:
     x = (x - x.mean()) / x.std()
     y = (y - y.mean()) / y.std()
     z = (z - z.mean()) / z.std()
@@ -90,7 +94,7 @@ def _joint_3(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde], damping: float
     return h3d
 
 
-def hgr_cond(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde]) -> np.ndarray:
+def hgr_cond(x: Tensor, y: Tensor, z: Tensor, density: Callable[[Tensor], Kde]) -> np.ndarray:
     r"""An estimator of the function :math:`z\to HGR(x|z, y|z)`.
 
     Where HGR is the Hirschfeld-Gebelein-Renyi maximum correlation
@@ -113,7 +117,7 @@ def hgr_cond(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde]) -> np.ndarray:
     return np.array([torch.svd(Q[:, :, i])[1][1] for i in range(Q.shape[2])])
 
 
-def chi_2_cond(x: Tensor, y: Tensor, z: Tensor, density: Type[Kde]) -> Tensor:
+def chi_2_cond(x: Tensor, y: Tensor, z: Tensor, density: Callable[[Tensor], Kde]) -> Tensor:
     r"""An estimator of the function :math:`z\to chi^2(x|z, y|z)`.
 
     Where :math:`\chi^2` is the :math:`\chi^2` divergence between the joint distribution on (x,y)

ethicml/run/cross_validator.py (+4 -3)

@@ -1,9 +1,10 @@
 """Cross Validation for any in process (at the moment) Algorithm."""
 
 from collections import defaultdict
+from collections.abc import Callable
 from itertools import product
 from statistics import mean
-from typing import Any, Mapping, NamedTuple, Sequence, Type
+from typing import Any, Mapping, NamedTuple, Sequence
 
 from ethicml.metrics.accuracy import Accuracy
 from ethicml.metrics.cv import AbsCV
@@ -55,7 +56,7 @@ class CVResults:
 
     """
 
-    def __init__(self, results: list[ResultTuple], model: type[InAlgorithm]):
+    def __init__(self, results: list[ResultTuple], model: Callable[..., InAlgorithm]):
        self.raw_storage = results
        self.model = model
        self.mean_storage = self._organize_and_compute_means()
@@ -195,7 +196,7 @@ class CrossValidator:
 
     def __init__(
         self,
-        model: Type[InAlgorithm],
+        model: Callable[..., InAlgorithm],
         hyperparams: Mapping[str, Sequence[Any]],
         folds: int = 3,
         max_parallel: int = 0,

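The practical effect for CrossValidator users, sketched under the assumption that EthicML re-exports LR, InAlgorithm, and CrossValidator at the top level (treat those names as assumptions): any callable that produces an InAlgorithm is now accepted, not just a bare class.

import ethicml as em

# A model class is one kind of Callable[..., InAlgorithm]:
cv = em.CrossValidator(em.LR, {"C": [0.01, 0.1, 1.0]}, folds=3)


# ...but a plain function works too, which type[InAlgorithm] forbade:
def lr_factory(**hyperparams: float) -> em.InAlgorithm:
    return em.LR(**hyperparams)


cv_wrapped = em.CrossValidator(lr_factory, {"C": [0.01, 0.1]}, folds=3)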
tests/data/dataset_modification_test.py (+6 -3)

@@ -1,6 +1,7 @@
 """Test modifications to a dataset."""
 
-from typing import Type, Union
+from collections.abc import Callable
+from typing import Union
 
 import pandas as pd
 import pytest
@@ -14,7 +15,8 @@
 @pytest.mark.parametrize("dataset_name", available_tabular())
 @pytest.mark.parametrize("scaler_type", [StandardScaler, MinMaxScaler])
 def test_scaling(
-    dataset_name: str, scaler_type: Union[Type[StandardScaler], Type[MinMaxScaler]]
+    dataset_name: str,
+    scaler_type: Union[Callable[[], StandardScaler], Callable[[], MinMaxScaler]],
 ) -> None:
     """Test that scaling works."""
     scaler = scaler_type()
@@ -42,7 +44,8 @@ def test_scaling(
 @pytest.mark.parametrize("dataset_name", available_tabular())
 @pytest.mark.parametrize("scaler_type", [StandardScaler, MinMaxScaler])
 def test_scaling_separate_test(
-    dataset_name: str, scaler_type: Union[Type[StandardScaler], Type[MinMaxScaler]]
+    dataset_name: str,
+    scaler_type: Union[Callable[[], StandardScaler], Callable[[], MinMaxScaler]],
 ) -> None:
     """Test that scaling works."""
     scaler = scaler_type()

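The same idea in miniature, as a standalone sketch (fit_fresh is hypothetical, not part of the test suite): sklearn's scaler classes already satisfy a zero-argument Callable, and so does a factory that pins configuration, which is exactly the flexibility Type[...] denied.

from collections.abc import Callable
from typing import Union

from sklearn.preprocessing import MinMaxScaler, StandardScaler

ScalerFactory = Union[Callable[[], StandardScaler], Callable[[], MinMaxScaler]]


def fit_fresh(
    factory: ScalerFactory, data: list[list[float]]
) -> Union[StandardScaler, MinMaxScaler]:
    scaler = factory()  # the only capability the tests rely on: a zero-arg call
    scaler.fit(data)
    return scaler


fit_fresh(StandardScaler, [[0.0], [1.0]])                               # class as factory
fit_fresh(lambda: MinMaxScaler(feature_range=(-1, 1)), [[0.0], [1.0]])  # configured factory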
tests/models_test/inprocess_test/cv_test.py (+5 -4)

@@ -1,6 +1,7 @@
 """Tests for cross validation."""
 
-from typing import Dict, List, NamedTuple, Sequence, Type, Union
+from collections.abc import Callable
+from typing import Dict, List, NamedTuple, Sequence, Union
 
 import numpy as np
 import pytest
@@ -14,7 +15,7 @@
 class CvParam(NamedTuple):
     """Specification of a unit test for cross validation."""
 
-    model: Type[InAlgorithm]
+    model: Callable[..., InAlgorithm]
     hyperparams: Dict[str, Union[Sequence[float], List[str], Sequence[KernelType]]]
     num_pos: int
 
@@ -32,7 +33,7 @@ class CvParam(NamedTuple):
 @pytest.mark.parametrize(("model", "hyperparams", "num_pos"), CV_PARAMS)
 def test_cv(
     toy_train_test: TrainTestPair,
-    model: Type[InAlgorithm],
+    model: Callable[..., InAlgorithm],
     hyperparams: Dict[str, Union[Sequence[float], List[str]]],
     num_pos: int,
 ) -> None:
@@ -53,7 +54,7 @@ def test_cv(
 @pytest.mark.parametrize(("model", "hyperparams", "num_pos"), CV_PARAMS)
 def test_parallel_cv(
     toy_train_test: TrainTestPair,
-    model: Type[InAlgorithm],
+    model: Callable[..., InAlgorithm],
     hyperparams: Dict[str, Union[Sequence[float], List[str]]],
     num_pos: int,
 ) -> None:

tests/models_test/inprocess_test/zafar_test.py (+3 -2)

@@ -3,7 +3,8 @@
 They are kept separate because they're very slow.
 """
 
-from typing import Dict, Generator, List, Type
+from collections.abc import Callable
+from typing import Dict, Generator, List
 
 import numpy as np
 import pytest
@@ -61,7 +62,7 @@ def test_zafar(toy_train_test: TrainTestPair) -> None:  # noqa: PLR0915
 
     hyperparams: Dict[str, List[float]] = {"gamma": [1, 1e-1, 1e-2]}
 
-    model_class: Type[InAlgorithm] = ZafarAccuracy
+    model_class: Callable[..., InAlgorithm] = ZafarAccuracy
     zafar_cv = CrossValidator(model_class, hyperparams, folds=3)
 
     assert zafar_cv is not None
