-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
111 lines (83 loc) · 3.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# test.py
import os, time
import numpy as np
import glob
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from PIL import Image
from Denoise import Denoise
import argparse
import logging
logging.basicConfig(level=logging.INFO)
# Command-line arguments describing the Denoise model architecture and
# the test-time patching configuration. These must match the settings the
# checkpoint was trained with.
parser = argparse.ArgumentParser(description="Denoise")
parser.add_argument(
    "--n_resblocks", type=int, default=32, help="number of residual blocks"
)
parser.add_argument("--n_feats", type=int, default=64, help="number of feature maps")
parser.add_argument("--res_scale", type=float, default=1, help="residual scaling")
# FIX: default must match the declared type. argparse does not run defaults
# through `type=`, so the old `default=1` yielded an int when the flag was
# omitted but a str when supplied on the CLI. Use the string "1" for both.
parser.add_argument("--scale", type=str, default="1", help="super resolution scale")
parser.add_argument("--patch_size", type=int, default=300, help="output patch size")
parser.add_argument(
    "--n_colors", type=int, default=3, help="number of input color channels to use"
)
parser.add_argument(
    "--o_colors", type=int, default=3, help="number of output color channels to use"
)
args = parser.parse_args()
# Fix the RNG seed so any stochastic layers behave reproducibly at test time.
torch.manual_seed(0)

# I/O locations: noisy inputs, trained weights, and the prediction output dir.
input_dir = "./test/low/"
m_path = "./Model/denoise-lerelu-ps-300-b-32/"
m_name = "denoise_e0035.pth"
result_dir = "./test/predicted/"

# Collect test IDs: one per PNG in input_dir, identified by the filename
# stem (name without the trailing ".png").
# FIX (idiom): replaced the index-based range(len(...)) append loop with a
# comprehension; os.path.split(fn)[1] is the basename, [:-4] drops ".png".
test_fns = glob.glob(input_dir + "/*.png")
test_ids = [os.path.split(fn)[1][:-4] for fn in test_fns]

ps = args.patch_size  # patch size for training
def load_image(path):
    """Load the image at *path* as a float32 numpy array scaled to [0, 1].

    Args:
        path: filesystem path to an image file readable by PIL.

    Returns:
        numpy.ndarray of dtype float32 with values in [0, 1]; shape is
        whatever PIL decodes (H, W) or (H, W, C) depending on the file.

    FIX: open the image with a context manager so the underlying file
    handle is closed deterministically — PIL's Image.open is lazy and
    otherwise keeps the file open until garbage collection.
    """
    with Image.open(path) as img:
        # np.array forces a full decode, so closing the file afterwards is safe.
        return np.array(img, dtype=np.float32) / 255.0
# Set device: prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info("reached 1")  # progress marker: model construction starting

# Build the denoising network from the parsed CLI arguments and move it
# to the selected device. Denoise is a project-local nn.Module.
model = Denoise(args).to(device)

# Wrap in DataParallel when more than one GPU is visible so inference is
# split across devices.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

# Load the trained weights; map_location keeps CPU-only machines working.
# NOTE(review): if the checkpoint was saved from an *unwrapped* model, its
# state-dict keys lack the "module." prefix a DataParallel wrapper expects
# (and vice versa) — confirm against how training saved the checkpoint.
model.load_state_dict(torch.load(m_path + m_name, map_location=device))
model.to(device)  # redundant when already moved above, but harmless

# Ensure the prediction output directory exists before the loop writes to it.
if not os.path.isdir(result_dir):
    os.makedirs(result_dir)
# Inference loop: run every collected test image through the network and
# save the denoised prediction as a PNG. Gradients are disabled throughout.
cnt = 0
with torch.no_grad():
    logging.info("loop started")
    for test_id in test_ids:
        # test the first image in each sequence
        in_files = glob.glob(input_dir + test_id + ".png")
        for in_path in in_files:
            # Load as (1, H, W, C) float32 in [0, 1], clamp the top end,
            # then permute to NCHW — the layout the network expects.
            input_full = load_image(in_path)
            input_full = np.expand_dims(input_full, axis=0)
            input_full = np.minimum(input_full, 1.0)
            in_img = torch.from_numpy(input_full).permute(0, 3, 1, 2).to(device)

            st = time.time()
            cnt += 1
            out_img = model(in_img)
            print("%d\tTime: %.3f" % (cnt, time.time() - st))

            # FIX: use .detach() instead of the deprecated `.data` attribute;
            # under torch.no_grad() the output carries no graph anyway.
            output = out_img.permute(0, 2, 3, 1).detach().cpu().numpy()
            # Clamp predictions into the valid [0, 1] image range before saving.
            output = np.clip(output, 0.0, 1.0)
            output = output[0, :, :, :]
            origin_full = input_full[0, :, :, :]
            # print("psnr: %.4f" % psnr[-1])

            # FIX: removed the redundant per-image makedirs check — the
            # result directory is created once before this loop runs.
            # NOTE(review): int(test_id) assumes numeric file stems such as
            # "00035.png"; a non-numeric name would raise ValueError here.
            plt.imsave(result_dir + "%05d_00_out.png" % int(test_id), output)

print("\n\n---------------------------------")
print(f"output files are stored in {result_dir}")