Update app.py
Browse files
app.py
CHANGED
@@ -7,50 +7,82 @@ from PIL import Image
|
|
7 |
import cv2
|
8 |
|
9 |
class TransformNet(nn.Module):
|
10 |
-
|
|
|
11 |
super(TransformNet, self).__init__()
|
12 |
self.conv1 = nn.Sequential(
|
13 |
-
nn.Conv2d(input_dim, 64, kernel_size=1,
|
14 |
nn.BatchNorm2d(64),
|
15 |
nn.ReLU()
|
16 |
)
|
17 |
self.conv2 = nn.Sequential(
|
18 |
-
nn.Conv2d(64, 128, kernel_size=1,
|
19 |
nn.BatchNorm2d(128),
|
20 |
nn.ReLU()
|
21 |
)
|
22 |
self.conv3 = nn.Sequential(
|
23 |
-
nn.
|
24 |
-
nn.
|
25 |
nn.ReLU()
|
26 |
)
|
27 |
self.fc = nn.Linear(1024, 512)
|
28 |
|
29 |
def forward(self, x):
|
30 |
-
x = x.unsqueeze(-1)
|
31 |
x = self.conv1(x)
|
32 |
x = self.conv2(x)
|
33 |
-
x =
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
class PointCloudEncoder(nn.Module):
|
|
|
37 |
def __init__(self):
|
38 |
super(PointCloudEncoder, self).__init__()
|
39 |
self.transform_net = TransformNet()
|
40 |
self.convs = nn.ModuleList([
|
41 |
nn.Sequential(
|
42 |
-
nn.Conv2d(512, 256, kernel_size=1,
|
43 |
nn.BatchNorm2d(256),
|
44 |
nn.ReLU()
|
45 |
),
|
46 |
nn.Sequential(
|
47 |
-
nn.Conv2d(256, 128, kernel_size=1,
|
48 |
nn.BatchNorm2d(128),
|
49 |
nn.ReLU()
|
50 |
),
|
51 |
nn.Sequential(
|
52 |
-
nn.
|
53 |
-
nn.
|
54 |
nn.ReLU()
|
55 |
)
|
56 |
])
|
@@ -58,9 +90,13 @@ class PointCloudEncoder(nn.Module):
|
|
58 |
|
59 |
def forward(self, x):
|
60 |
x = self.transform_net(x)
|
61 |
-
for conv in self.convs:
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
64 |
|
65 |
class GarmentEncoder(nn.Module):
|
66 |
"""Garment Feature Encoder (garm_enc)"""
|
@@ -110,7 +146,7 @@ class CloseNet(nn.Module):
|
|
110 |
return self.segm_dec(features)
|
111 |
|
112 |
# Load Pretrained Model
|
113 |
-
model_path = "
|
114 |
model = CloseNet()
|
115 |
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
|
116 |
model.eval()
|
|
|
7 |
import cv2
|
8 |
|
9 |
class TransformNet(nn.Module):
|
10 |
+
"""Transformation Network for PointCloud Encoding"""
|
11 |
+
def __init__(self, input_dim=6): # β
Ensure input has 6 channels
|
12 |
super(TransformNet, self).__init__()
|
13 |
self.conv1 = nn.Sequential(
|
14 |
+
nn.Conv2d(input_dim, 64, kernel_size=(1, 1)), # β
Conv2d (Matches checkpoint)
|
15 |
nn.BatchNorm2d(64),
|
16 |
nn.ReLU()
|
17 |
)
|
18 |
self.conv2 = nn.Sequential(
|
19 |
+
nn.Conv2d(64, 128, kernel_size=(1, 1)), # β
Conv2d (Matches checkpoint)
|
20 |
nn.BatchNorm2d(128),
|
21 |
nn.ReLU()
|
22 |
)
|
23 |
self.conv3 = nn.Sequential(
|
24 |
+
nn.Conv1d(128, 1024, kernel_size=1), # β
Conv1d to match `[1024, 128, 1]`
|
25 |
+
nn.BatchNorm1d(1024),
|
26 |
nn.ReLU()
|
27 |
)
|
28 |
self.fc = nn.Linear(1024, 512)
|
29 |
|
30 |
def forward(self, x):
|
31 |
+
x = x.unsqueeze(-1) # β
Add extra dimension for Conv2d
|
32 |
x = self.conv1(x)
|
33 |
x = self.conv2(x)
|
34 |
+
x = x.squeeze(-1) # β
Remove extra dimension before Conv1d
|
35 |
+
x = self.conv3(x) # β
Now it correctly matches `[1024, 128, 1]`
|
36 |
+
return self.fc(x.max(dim=-1)[0]) # β
Fix pooling
|
37 |
+
|
38 |
+
class TransformNet(nn.Module):
|
39 |
+
"""Transformation Network for PointCloud Encoding"""
|
40 |
+
def __init__(self, input_dim=6): # β
Ensure input has 6 channels
|
41 |
+
super(TransformNet, self).__init__()
|
42 |
+
self.conv1 = nn.Sequential(
|
43 |
+
nn.Conv2d(input_dim, 64, kernel_size=(1, 1)), # β
Conv2d (Matches checkpoint)
|
44 |
+
nn.BatchNorm2d(64),
|
45 |
+
nn.ReLU()
|
46 |
+
)
|
47 |
+
self.conv2 = nn.Sequential(
|
48 |
+
nn.Conv2d(64, 128, kernel_size=(1, 1)), # β
Conv2d (Matches checkpoint)
|
49 |
+
nn.BatchNorm2d(128),
|
50 |
+
nn.ReLU()
|
51 |
+
)
|
52 |
+
self.conv3 = nn.Sequential(
|
53 |
+
nn.Conv1d(128, 1024, kernel_size=1), # β
Conv1d to match `[1024, 128, 1]`
|
54 |
+
nn.BatchNorm1d(1024),
|
55 |
+
nn.ReLU()
|
56 |
+
)
|
57 |
+
self.fc = nn.Linear(1024, 512)
|
58 |
+
|
59 |
+
def forward(self, x):
|
60 |
+
x = x.unsqueeze(-1) # β
Add extra dimension for Conv2d
|
61 |
+
x = self.conv1(x)
|
62 |
+
x = self.conv2(x)
|
63 |
+
x = x.squeeze(-1) # β
Remove extra dimension before Conv1d
|
64 |
+
x = self.conv3(x) # β
Now it correctly matches `[1024, 128, 1]`
|
65 |
+
return self.fc(x.max(dim=-1)[0]) # β
Fix pooling
|
66 |
|
67 |
class PointCloudEncoder(nn.Module):
|
68 |
+
"""Point Cloud Encoder (pc_enc)"""
|
69 |
def __init__(self):
|
70 |
super(PointCloudEncoder, self).__init__()
|
71 |
self.transform_net = TransformNet()
|
72 |
self.convs = nn.ModuleList([
|
73 |
nn.Sequential(
|
74 |
+
nn.Conv2d(512, 256, kernel_size=(1, 1)), # β
Conv2d (Matches checkpoint)
|
75 |
nn.BatchNorm2d(256),
|
76 |
nn.ReLU()
|
77 |
),
|
78 |
nn.Sequential(
|
79 |
+
nn.Conv2d(256, 128, kernel_size=(1, 1)), # β
Conv2d (Matches checkpoint)
|
80 |
nn.BatchNorm2d(128),
|
81 |
nn.ReLU()
|
82 |
),
|
83 |
nn.Sequential(
|
84 |
+
nn.Conv1d(128, 64, kernel_size=1), # β
Conv1d to match `[64, 128, 1]`
|
85 |
+
nn.BatchNorm1d(64),
|
86 |
nn.ReLU()
|
87 |
)
|
88 |
])
|
|
|
90 |
|
91 |
def forward(self, x):
|
92 |
x = self.transform_net(x)
|
93 |
+
for i, conv in enumerate(self.convs):
|
94 |
+
if i < 2:
|
95 |
+
x = conv(x) # β
Conv2d
|
96 |
+
else:
|
97 |
+
x = x.squeeze(-1) # β
Remove extra dimension before Conv1d
|
98 |
+
x = conv(x) # β
Conv1d (Matches `[64, 128, 1]`)
|
99 |
+
return self.lin_global(x.max(dim=-1)[0]) # β
Fix pooling
|
100 |
|
101 |
class GarmentEncoder(nn.Module):
|
102 |
"""Garment Feature Encoder (garm_enc)"""
|
|
|
146 |
return self.segm_dec(features)
|
147 |
|
148 |
# Load Pretrained Model
|
149 |
+
model_path = "/content/closenet.pth"
|
150 |
model = CloseNet()
|
151 |
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
|
152 |
model.eval()
|