-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathfair_dummies_romano.py
86 lines (71 loc) · 3.01 KB
/
fair_dummies_romano.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Fair Dummies Implementation."""
import json
from pathlib import Path
import random
import sys
from typing import TYPE_CHECKING
from joblib import dump, load
import numpy as np
import torch
from ethicml.implementations.fair_dummies_modules.model import EquiClassLearner
from ethicml.utility import DataTuple, ModelType, SoftPrediction, SubgroupTuple, TestTuple
if TYPE_CHECKING:
from ethicml.models.inprocess.fair_dummies import FairDummiesArgs
from ethicml.models.inprocess.in_subprocess import InAlgoArgs
def fit(train: DataTuple, args: "FairDummiesArgs", seed: int = 888) -> EquiClassLearner:
    """Build an ``EquiClassLearner`` from ``args`` and train it on ``train``.

    :param train: Training data. The number of feature columns becomes the network input
        width and the number of distinct labels becomes the number of output classes.
    :param args: Fair Dummies hyperparameters, read by key.
    :param seed: Seed applied to Python's, NumPy's and PyTorch's RNGs (and all CUDA devices
        when CUDA is available); it is also passed to the learner and its ``fit`` call.
    :returns: The result of ``EquiClassLearner.fit``.
    """
    # Seed every RNG that training may draw from, in the same order as before.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Sizes derived from the training data rather than from the hyperparameters.
    n_features = len(train.x.columns)
    n_classes = train.y.nunique()

    learner = EquiClassLearner(
        in_shape=n_features,
        num_classes=n_classes,
        cost_pred=torch.nn.CrossEntropyLoss(),
        model_type=ModelType(args["model_type"]),
        lr=args["lr"],
        batch_size=args["batch_size"],
        epochs=args["epochs"],
        pretrain_pred_epochs=args["pretrain_pred_epochs"],
        pretrain_dis_epochs=args["pretrain_dis_epochs"],
        loss_steps=args["loss_steps"],
        dis_steps=args["dis_steps"],
        lambda_vec=args["lambda_vec"],
        second_moment_scaling=args["second_moment_scaling"],
        seed=seed,
    )
    trained = learner.fit(train, seed=seed)
    return trained
def predict(model: EquiClassLearner, test: TestTuple) -> np.ndarray:
    """Run a fitted learner on the features of ``test``.

    :param model: A fitted learner, e.g. the result of :func:`fit`.
    :param test: Test data; only its feature frame ``test.x`` is used.
    :returns: Whatever ``model.predict`` returns for those features.
    """
    features = test.x
    return model.predict(features)
def train_and_predict(
    train: DataTuple, test: TestTuple, args: "FairDummiesArgs", seed: int
) -> np.ndarray:
    """Train a Fair Dummies model and compute predictions on the given test data.

    :param train: Training data passed to :func:`fit`.
    :param test: Data whose features are passed to :func:`predict`.
    :param args: Fair Dummies hyperparameters.
    :param seed: Seed forwarded to :func:`fit`.
    :returns: The trained model's predictions for ``test``.
    """
    model = fit(train, args, seed)
    return predict(model, test)
def main() -> None:
    """Run the Fair Dummies model as a standalone program.

    Expects two JSON-encoded command-line arguments: ``sys.argv[1]`` holds the
    in-process-algorithm arguments (``mode``, ``seed`` and the file paths the mode needs)
    and ``sys.argv[2]`` holds the Fair Dummies hyperparameters.

    Modes:
        ``run``: train on ``train``, predict on ``test`` and save soft predictions to
        ``predictions``.
        ``fit``: train on ``train`` and dump the model with joblib to ``model``.
        ``predict``: load the model from ``model``, predict on ``test`` and save soft
        predictions to ``predictions``.

    :raises RuntimeError: If ``mode`` is none of the above.
    """
    # Local-variable annotations are never evaluated at runtime, so the TYPE_CHECKING-only
    # names used here are safe.
    in_algo_args: InAlgoArgs = json.loads(sys.argv[1])
    flags: FairDummiesArgs = json.loads(sys.argv[2])
    mode = in_algo_args["mode"]

    if mode == "run":
        train = DataTuple.from_file(Path(in_algo_args["train"]))
        test = SubgroupTuple.from_file(Path(in_algo_args["test"]))
        soft = train_and_predict(train, test, flags, in_algo_args["seed"])
        SoftPrediction(soft=soft).save_to_file(Path(in_algo_args["predictions"]))
    elif mode == "fit":
        data = DataTuple.from_file(Path(in_algo_args["train"]))
        model = fit(data, flags, in_algo_args["seed"])
        # Attach the seed to the model object so it is persisted together with it.
        # NOTE(review): the "predict" branch below never reads this attribute back;
        # confirm whether prediction is meant to reseed from it.
        model.ethicml_random_seed = in_algo_args["seed"]  # type: ignore
        dump(model, Path(in_algo_args["model"]))
    elif mode == "predict":
        test = SubgroupTuple.from_file(Path(in_algo_args["test"]))
        # joblib.load unpickles the file: only point it at model files written by the
        # "fit" mode or another trusted source.
        model = load(Path(in_algo_args["model"]))
        SoftPrediction(soft=predict(model, test)).save_to_file(Path(in_algo_args["predictions"]))
    else:
        raise RuntimeError(f"Unknown mode: {mode}")


if __name__ == "__main__":
    main()