gaur3009 commited on
Commit
89b4f01
·
verified ·
1 Parent(s): f994235

Upload infer.py

Browse files
Files changed (1) hide show
  1. infer.py +86 -0
infer.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from tqdm import tqdm
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
+ import warnings
8
+
9
+ warnings.filterwarnings("ignore", category=FutureWarning)
10
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import torchvision.transforms as transforms
15
+
16
+ from data.base_dataset import Normalize_image
17
+ from utils.saving_utils import load_checkpoint_mgpu
18
+
19
+ from networks import U2NET
20
+
21
# --- Inference configuration -------------------------------------------------
# Fall back to CPU when CUDA is unavailable so the script also runs on
# GPU-less hosts instead of crashing at net.to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"

image_dir = "input_images"    # directory containing the photos to segment
result_dir = "output_images"  # directory where mask PNGs are written
checkpoint_path = os.path.join("trained_checkpoint", "cloth_segm_u2net_latest.pth")
do_palette = True             # attach a color palette to the saved masks
27
+
28
+
29
def get_palette(num_cls):
    """Build a flat RGB color map for visualizing segmentation masks.

    Args:
        num_cls: Number of classes.

    Returns:
        A list of 3 * num_cls ints laid out as [R0, G0, B0, R1, G1, B1, ...],
        using the bit-interleaved (PASCAL VOC style) color assignment: the
        bits of each class index are spread across the high bits of the
        R/G/B channels.
    """
    colors = []
    for cls_idx in range(num_cls):
        r = g = b = 0
        code = cls_idx
        shift = 7
        # Consume the class index three bits at a time, one bit per channel,
        # placing them at successively lower bit positions.
        while code:
            r |= (code & 1) << shift
            g |= ((code >> 1) & 1) << shift
            b |= ((code >> 2) & 1) << shift
            code >>= 3
            shift -= 1
        colors.extend((r, g, b))
    return colors
51
+
52
+
53
# Preprocessing pipeline: PIL image -> tensor scaled to [0, 1] -> normalized.
transforms_list = []
transforms_list += [transforms.ToTensor()]
transforms_list += [Normalize_image(0.5, 0.5)]
transform_rgb = transforms.Compose(transforms_list)

# Cloth-segmentation U2NET: 3-channel RGB in, 4 class logits out.
net = U2NET(in_ch=3, out_ch=4)
net = load_checkpoint_mgpu(net, checkpoint_path)
net = net.to(device)
net = net.eval()

palette = get_palette(4)

# Create the output directory up front; otherwise the first save() raises
# FileNotFoundError when it does not already exist.
os.makedirs(result_dir, exist_ok=True)

images_list = sorted(os.listdir(image_dir))
pbar = tqdm(total=len(images_list))
for image_name in images_list:
    img = Image.open(os.path.join(image_dir, image_name)).convert("RGB")
    image_tensor = transform_rgb(img)
    image_tensor = torch.unsqueeze(image_tensor, 0)  # add batch dimension

    # Pure inference: disable autograd to avoid building a graph and
    # holding intermediate activations in memory.
    with torch.no_grad():
        output_tensor = net(image_tensor.to(device))
    # output_tensor[0] selects the first of U2NET's outputs — presumably the
    # fused prediction map; confirm against networks.U2NET if in doubt.
    output_tensor = F.log_softmax(output_tensor[0], dim=1)
    output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_tensor = torch.squeeze(output_tensor, dim=0)
    output_arr = output_tensor.cpu().numpy()

    # Save the class-index mask as an 8-bit image, optionally palettized.
    output_img = Image.fromarray(output_arr.astype("uint8"), mode="L")
    if do_palette:
        output_img.putpalette(palette)
    # splitext handles extensions of any length; the original
    # image_name[:-3] + "png" mangled 4-char extensions ("x.jpeg" -> "x.jpng").
    output_img.save(os.path.join(result_dir, os.path.splitext(image_name)[0] + ".png"))

    pbar.update(1)

pbar.close()