-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathadjust_labels.py
52 lines (36 loc) · 1.66 KB
/
adjust_labels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Simple class to make class labels binary. Useful if a network uses BCELoss for example."""
from __future__ import annotations
import numpy as np
import pandas as pd
from ethicml.utility import DataTuple
__all__ = ["LabelBinarizer"]
def assert_binary_labels(data_tuple: DataTuple) -> None:
    """Assert that the labels of a dataset are exactly the two values 0 and 1.

    The checks are raised explicitly (rather than via ``assert`` statements) so
    that they are still enforced when Python runs with ``-O``.

    :param data_tuple: Object whose ``y`` attribute holds the class labels.
    :raises AssertionError: If ``y`` does not have exactly two distinct values,
        or if those values are not 0 and 1.
    """
    labels = data_tuple.y
    if labels.nunique() != 2:
        raise AssertionError("labels must contain exactly two distinct values")
    # ``array_equal`` returns False on a shape mismatch (e.g. when NaNs add a
    # third unique entry) instead of raising a broadcasting error.
    if not np.array_equal(np.unique(labels.to_numpy()), np.array([0, 1])):
        raise AssertionError("labels must be exactly 0 and 1")
class LabelBinarizer:
    """Map a two-valued label column onto [0, 1] and back again.

    For example, if a dataset has labels [-1, 1], ``adjust`` turns them into
    [0, 1] and ``post`` / ``post_only_labels`` restores the original values.
    The original label values are remembered on the instance by ``adjust``.
    """

    # Smaller / larger of the two original label values; set by ``adjust``.
    min_val: int
    max_val: int

    def adjust(self, dataset: DataTuple) -> DataTuple:
        """Take a datatuple and make the labels [0, 1].

        The smaller original label becomes 0 and the larger becomes 1. The
        original values are stored in ``min_val`` and ``max_val``.

        :param dataset: Dataset whose ``y`` has exactly two distinct values.
        :returns: A new dataset (via ``dataset.replace``) with binarized labels.
        :raises AssertionError: If ``y`` does not have exactly two distinct values.
        """
        labels = dataset.y
        if labels.nunique() != 2:
            raise AssertionError("labels must contain exactly two distinct values")

        values = labels.to_numpy()
        self.min_val = values.min().item()
        self.max_val = values.max().item()

        # Build both masks from the untouched labels *before* substituting, so
        # that an original label equal to 0 or 1 (e.g. labels [-1, 0]) is not
        # rewritten a second time by the following substitution.
        is_min = labels == self.min_val
        is_max = labels == self.max_val
        new_y = labels.mask(is_min, 0).mask(is_max, 1)
        return dataset.replace(y=new_y)

    def post_only_labels(self, labels: pd.Series) -> pd.Series:
        """Inverse of adjust, but for a bare label Series instead of a DataTuple.

        :param labels: Series of 0/1 labels with exactly two distinct values.
        :returns: A new Series with 0 replaced by ``min_val`` and 1 by ``max_val``.
        :raises AssertionError: If ``labels`` does not have exactly two distinct values.
        """
        if labels.nunique() != 2:
            raise AssertionError("labels must contain exactly two distinct values")

        # Masks are taken from the input up front so that restoring, say,
        # 0 -> 1 does not get caught by the subsequent 1 -> max_val step.
        is_zero = labels == 0
        is_one = labels == 1
        return labels.mask(is_zero, self.min_val).mask(is_one, self.max_val)

    def post(self, dataset: DataTuple) -> DataTuple:
        """Inverse of adjust.

        :param dataset: Dataset whose ``y`` holds 0/1 labels.
        :returns: A new dataset whose labels are the original values seen by ``adjust``.
        """
        transformed_y = self.post_only_labels(dataset.y)
        return dataset.replace(y=pd.Series(transformed_y, name=dataset.y.name))