-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmetrics.py
84 lines (65 loc) · 2.8 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import cv2
import torch
import math
from tqdm import tqdm
import lpips
import glob
from skimage.metrics import structural_similarity as ssim
import argparse
import sys
import os
def calc_psnr_np(sr, hr, range=255.):
    """Return the peak signal-to-noise ratio (in dB) between two arrays.

    Both arrays are cast to float32, their difference is divided by
    ``range``, and PSNR is ``-10 * log10(MSE)`` of that scaled difference.

    Args:
        sr: Output/restored image as a NumPy array.
        hr: Reference image; same shape as ``sr`` (or broadcastable to it).
        range: Peak signal value, 255 for 8-bit images. The name shadows the
            builtin but is kept so keyword callers keep working.

    Returns:
        PSNR as a Python float; ``float('inf')`` when the arrays are identical.
    """
    diff = (sr.astype(np.float32) - hr.astype(np.float32)) / range
    mse = float(np.mean(np.square(diff)))
    if mse == 0:
        # Identical inputs: math.log10(0) would raise "math domain error".
        return float('inf')
    return -10 * math.log10(mse)
def lpips_norm(img, device=None):
    """Convert an image array into an LPIPS input tensor.

    A batch axis is added and the last axis is moved to the front, i.e.
    ``(H, W, C) -> (1, C, H, W)``. Values are rescaled with ``x / 127.5 - 1``,
    so [0, 255] maps to [-1, 1], the range the ``lpips`` package expects by
    default.

    Args:
        img: Rank-3 array laid out as ``(H, W, C)``; presumably 8-bit RGB.
        device: Target torch device. When omitted, the module-level ``device``
            (set by the ``__main__`` block) is used if it exists, else CPU.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, C, H, W)`` on ``device``.
    """
    if device is None:
        # Stay compatible with the script, which defines a global ``device``,
        # while still working when this function is imported elsewhere.
        device = globals().get('device', torch.device('cpu'))
    batched = img[:, :, :, np.newaxis].transpose((3, 2, 0, 1))
    normed = batched / (255. / 2.) - 1
    # np.ascontiguousarray yields a C-contiguous, positive-stride float32
    # buffer, which torch.from_numpy can always wrap.
    return torch.from_numpy(np.ascontiguousarray(normed, dtype=np.float32)).to(device)
def calc_lpips(out, target, loss_fn_alex):
    """Return the LPIPS distance between two images as a Python float.

    Args:
        out: Output image; converted with ``lpips_norm``, so expected to be a
            ``(H, W, C)`` array with values in [0, 255].
        target: Reference image in the same layout as ``out``.
        loss_fn_alex: An ``lpips.LPIPS`` model (any callable taking two
            tensors and returning a single-element tensor works).

    Returns:
        The distance reported by ``loss_fn_alex`` as a float.
    """
    lpips_out = lpips_norm(out)
    lpips_target = lpips_norm(target)
    # Evaluation only: disable autograd so no graph or activations are kept.
    with torch.no_grad():
        distance = loss_fn_alex(lpips_out, lpips_target)
    return distance.item()
def calc_metrics(out, target, loss_fn_alex):
    """Score one image against its reference with PSNR, SSIM and LPIPS.

    Args:
        out: Output image as a ``(H, W, C)`` array with values in [0, 255].
        target: Reference image with the same shape as ``out``.
        loss_fn_alex: LPIPS model forwarded to ``calc_lpips``.

    Returns:
        A float NumPy array ``[psnr, ssim, lpips]`` in that order.
    """
    # NOTE(review): scikit-image >= 0.19 deprecates ``multichannel`` in favour
    # of ``channel_axis``; both are passed here, presumably for compatibility
    # across versions - confirm before dropping either.
    ssim_kwargs = dict(win_size=11, data_range=255, multichannel=True,
                       gaussian_weights=True, channel_axis=2)
    scores = [
        calc_psnr_np(out, target),
        ssim(out, target, **ssim_kwargs),
        calc_lpips(out, target, loss_fn_alex),
    ]
    return np.array(scores, dtype=float)
if __name__ == '__main__':
    # Average PSNR / SSIM / LPIPS of every image in <script dir>/ckpt/<name>/output/
    # against its ground truth in <dataroot>/test/gt/; print the result and
    # append it to <script dir>/ckpt/<name>/log_metrics.txt.
    parser = argparse.ArgumentParser(description='Metrics for argparse')
    parser.add_argument('--name', type=str, required=True,
                        help='Name of the folder to save models and logs.')
    parser.add_argument('--dataroot', type=str, default='/Dataset/Real-NAID/')
    parser.add_argument('--device', default="0")
    args = parser.parse_args()

    # Kept at module level on purpose: lpips_norm() looks up this global name.
    device = torch.device("cuda:" + args.device if torch.cuda.is_available() else "cpu")
    loss_fn_alex_v1 = lpips.LPIPS(net='alex', version='0.1').to(device)

    # sys.path[0] is the directory containing the invoked script.
    root = sys.path[0]
    files = [
        root + '/ckpt/' + args.name,
    ]
    gt_dir = os.path.join(args.dataroot, 'test', 'gt')

    for file in files:
        print('Start to measure images in %s...' % (file))
        output_dir = os.path.join(file, 'output')
        # One [PSNR, SSIM, LPIPS] row per image actually found, so the mean is
        # never diluted by unused zero rows and any image count fits.
        rows = []
        # Sorted for a deterministic processing order across runs.
        for image_file in tqdm(sorted(os.listdir(output_dir))):
            # GT name = output name with its last 9 characters replaced by 'gt.png'.
            gt_path = os.path.join(gt_dir, image_file[:-9] + 'gt.png')
            out_path = os.path.join(output_dir, image_file)
            gt = cv2.imread(gt_path)
            output = cv2.imread(out_path)
            # cv2.imread returns None instead of raising on unreadable files.
            if gt is None:
                raise FileNotFoundError('Cannot read ground-truth image: %s' % gt_path)
            if output is None:
                raise FileNotFoundError('Cannot read output image: %s' % out_path)
            # OpenCV decodes to BGR; reverse the channel axis to get RGB.
            rows.append(calc_metrics(output[..., ::-1], gt[..., ::-1], loss_fn_alex_v1))

        if not rows:
            print('No images found in %s, skipping.' % output_dir)
            continue

        mean_metrics = np.mean(rows, axis=0)
        header = '\n File :\t %s \n' % (file)
        summary = (' Original GT :\t PSNR = %.2f, SSIM = %.4f, LPIPS = %.3f \n'
                   % (mean_metrics[0], mean_metrics[1], mean_metrics[2]))
        print(header)
        print(summary)
        # ``with`` guarantees the log is closed even if a write fails.
        with open(os.path.join(file, 'log_metrics.txt'), 'a') as f:
            f.write(header)
            f.write(summary)