-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy patheval.py
29 lines (25 loc) · 943 Bytes
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import sys
import argparse
import vaetc
from vis import main as visualize_gwae
from clu import main as visualize_cluster
sys.path.append("./")
import models
sys.path.pop()
def parse_args(argv=None):
    """Parse the command-line options of this evaluation script.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            means ``sys.argv[1:]`` (standard ``argparse`` behaviour).

    Returns:
        An ``argparse.Namespace`` with ``logger_path``, ``checkpoint`` and the
        boolean switches ``no_evaluate``, ``no_quant``, ``no_qual``,
        ``no_gwae`` and ``no_cluster``.
    """
    parser = argparse.ArgumentParser(
        description="Load a checkpoint from a logger directory, evaluate it "
                    "with vaetc and run the GWAE / cluster visualizations.",
    )
    parser.add_argument(
        "logger_path", type=str,
        help="directory that contains the checkpoint file",
    )
    # Previously hard-coded; the default keeps the old behaviour. An absolute
    # path here overrides logger_path because os.path.join discards earlier
    # components when a later one is absolute.
    parser.add_argument(
        "--checkpoint", type=str, default="checkpoint_last.pth",
        help="checkpoint file name, relative to logger_path "
             "(default: %(default)s)",
    )
    parser.add_argument(
        "--no-evaluate", action="store_true",
        help="skip vaetc.evaluate entirely (makes --no-quant/--no-qual moot)",
    )
    parser.add_argument(
        "--no-quant", action="store_true",
        help="pass quantitative=False to vaetc.evaluate",
    )
    parser.add_argument(
        "--no-qual", action="store_true",
        help="pass qualitative=False to vaetc.evaluate",
    )
    parser.add_argument(
        "--no-gwae", action="store_true",
        help="skip the GWAE visualization step",
    )
    parser.add_argument(
        "--no-cluster", action="store_true",
        help="skip the cluster visualization step",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Load the checkpoint and run the requested evaluation/visualization steps.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    cp = vaetc.load_checkpoint(os.path.join(args.logger_path, args.checkpoint))

    if not args.no_evaluate:
        vaetc.evaluate(
            cp,
            qualitative=not args.no_qual,
            quantitative=not args.no_quant,
        )
    if not args.no_gwae:
        visualize_gwae(cp)
    if not args.no_cluster:
        visualize_cluster(cp)


if __name__ == "__main__":
    main()