-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathis.py
100 lines (76 loc) · 3.22 KB
/
is.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from scipy.stats import entropy
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models.inception import inception_v3
def inception_score(gan_loader: DataLoader,
                    batch_size: int,
                    *,
                    model: Optional[nn.Module] = None) -> float:
    r"""Calculates inception score using Inception_v3.

    IS = exp(mean_x KL(p(y|x) || p(y))), where p(y|x) is the classifier's
    softmax output for image x and p(y) is the mean of p(y|x) over all
    images yielded by the loader.

    Args:
        -gan_loader (DataLoader): Loader for GAN generated images. Each item
            must be an ``(images, labels)`` pair; the labels are ignored.
            With the default Inception model, images are expected to be
            3-channel and scaled to [-1, 1] (as produced by ``get_loader``).
        -batch_size (int): Size of batch for Inception inference. Kept for
            backward compatibility only; the number of images is taken
            from the batches the loader actually yields, so a partial last
            batch (``drop_last=False``) is handled correctly.
        -model (nn.Module, optional): Classifier mapping a batch of
            299x299 images to class logits. Defaults to a pretrained
            ImageNet Inception_v3. A supplied model is switched to eval
            mode and run on the device of its parameters (CPU if it has
            none).
    Returns:
        -score (float): Inception score.
    Raises:
        -ValueError: If the loader yields no batches.
    """
    if model is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # With pretrained weights torchvision defaults to transform_input=True,
        # which assumes ImageNet mean/std-normalized input and rescales it to
        # [-1, 1]. Our images are already in [-1, 1], so that extra
        # transform must be disabled or the inputs get distorted.
        model = inception_v3(pretrained=True, transform_input=False).to(device)
    else:
        device = next((p.device for p in model.parameters()),
                      torch.device("cpu"))
    model.eval()

    # Inception expects input to be in shape (299, 299)
    upsampler = nn.Upsample(size=(299, 299),
                            mode="bilinear",
                            align_corners=True)

    # Collect per-batch class probabilities instead of preallocating
    # len(loader) * batch_size rows, so partial batches are exact.
    batch_probs = []
    with torch.no_grad():
        for batch, _ in gan_loader:
            logits = model(upsampler(batch.to(device)))
            probs = nn.functional.softmax(logits, dim=1)
            batch_probs.append(probs.cpu().numpy())

    if not batch_probs:
        raise ValueError("gan_loader yielded no images")
    preds = np.concatenate(batch_probs, axis=0)

    # Marginal class distribution p(y), averaged over all p(y|x).
    py = np.mean(preds, axis=0)
    # entropy(pk, qk) is the KL divergence KL(pk || qk).
    kl_divs = [entropy(pyx, py) for pyx in preds]
    return float(np.exp(np.mean(kl_divs)))
def get_loader(path: str, img_size: int, batch_size: int) -> DataLoader:
    r"""Creates DataLoader instance for GAN generated images.

    Images are resized, converted to tensors and scaled from [0, 1] to
    [-1, 1] per channel. ``ImageFolder``'s default loader converts every
    image to RGB, so there are always three channels.

    Args:
        -path (str): Path of images generated by GAN. Must follow the
            ``ImageFolder`` layout (images inside subdirectories of path).
        -img_size (int): Size of image that GAN is producing. Passed to
            ``transforms.Resize`` as a single int, so the shorter side is
            resized to this value and the aspect ratio is kept.
        -batch_size (int): Size of batch for Inception inference.
    Returns:
        -gan_loader (DataLoader): Loader for GAN generated images. Batches
            come in file order (no shuffling) and an incomplete final
            batch is dropped.
    """
    # Per-channel RGB mean/std. A bare ``(0.5)`` is a float, not a tuple:
    # it only works via broadcasting in newer torchvision and fails in
    # older releases that iterate over mean/std per channel.
    rgb_half = (0.5, 0.5, 0.5)
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(rgb_half, rgb_half),
    ])
    gan_dataset = ImageFolder(path, transform=transform)
    return DataLoader(gan_dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      drop_last=True)
if __name__ == "__main__":
    # Command-line entry point: build a loader over the generated images and
    # print their Inception score.
    cli = argparse.ArgumentParser()
    cli.add_argument("--batch_size", type=int, default=32,
                     help="Size of batch for Inception inference.")
    cli.add_argument("--img_size", type=int, default=64,
                     help="Size of image that GAN is producing.")
    cli.add_argument("--path", type=str, required=True,
                     help="Path of images generated by GAN.")
    opts = cli.parse_args()

    loader = get_loader(opts.path, opts.img_size, opts.batch_size)
    print(inception_score(loader, opts.batch_size))