gaur3009 commited on
Commit
0083878
Β·
verified Β·
1 Parent(s): aa15348

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -16
app.py CHANGED
@@ -7,50 +7,82 @@ from PIL import Image
7
  import cv2
8
 
9
  class TransformNet(nn.Module):
10
- def __init__(self, input_dim=6):
 
11
  super(TransformNet, self).__init__()
12
  self.conv1 = nn.Sequential(
13
- nn.Conv2d(input_dim, 64, kernel_size=1, stride=1),
14
  nn.BatchNorm2d(64),
15
  nn.ReLU()
16
  )
17
  self.conv2 = nn.Sequential(
18
- nn.Conv2d(64, 128, kernel_size=1, stride=1),
19
  nn.BatchNorm2d(128),
20
  nn.ReLU()
21
  )
22
  self.conv3 = nn.Sequential(
23
- nn.Conv2d(128, 1024, kernel_size=1, stride=1),
24
- nn.BatchNorm2d(1024),
25
  nn.ReLU()
26
  )
27
  self.fc = nn.Linear(1024, 512)
28
 
29
  def forward(self, x):
30
- x = x.unsqueeze(-1)
31
  x = self.conv1(x)
32
  x = self.conv2(x)
33
- x = self.conv3(x)
34
- return self.fc(x.max(dim=-1)[0].squeeze(-1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  class PointCloudEncoder(nn.Module):
 
37
  def __init__(self):
38
  super(PointCloudEncoder, self).__init__()
39
  self.transform_net = TransformNet()
40
  self.convs = nn.ModuleList([
41
  nn.Sequential(
42
- nn.Conv2d(512, 256, kernel_size=1, stride=1),
43
  nn.BatchNorm2d(256),
44
  nn.ReLU()
45
  ),
46
  nn.Sequential(
47
- nn.Conv2d(256, 128, kernel_size=1, stride=1),
48
  nn.BatchNorm2d(128),
49
  nn.ReLU()
50
  ),
51
  nn.Sequential(
52
- nn.Conv2d(128, 64, kernel_size=1, stride=1),
53
- nn.BatchNorm2d(64),
54
  nn.ReLU()
55
  )
56
  ])
@@ -58,9 +90,13 @@ class PointCloudEncoder(nn.Module):
58
 
59
  def forward(self, x):
60
  x = self.transform_net(x)
61
- for conv in self.convs:
62
- x = conv(x)
63
- return self.lin_global(x.max(dim=-1)[0].squeeze(-1))
 
 
 
 
64
 
65
  class GarmentEncoder(nn.Module):
66
  """Garment Feature Encoder (garm_enc)"""
@@ -110,7 +146,7 @@ class CloseNet(nn.Module):
110
  return self.segm_dec(features)
111
 
112
  # Load Pretrained Model
113
- model_path = "model_arch/closenet.pth"
114
  model = CloseNet()
115
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
116
  model.eval()
 
7
  import cv2
8
 
9
  class TransformNet(nn.Module):
10
+ """Transformation Network for PointCloud Encoding"""
11
+ def __init__(self, input_dim=6): # βœ… Ensure input has 6 channels
12
  super(TransformNet, self).__init__()
13
  self.conv1 = nn.Sequential(
14
+ nn.Conv2d(input_dim, 64, kernel_size=(1, 1)), # βœ… Conv2d (Matches checkpoint)
15
  nn.BatchNorm2d(64),
16
  nn.ReLU()
17
  )
18
  self.conv2 = nn.Sequential(
19
+ nn.Conv2d(64, 128, kernel_size=(1, 1)), # βœ… Conv2d (Matches checkpoint)
20
  nn.BatchNorm2d(128),
21
  nn.ReLU()
22
  )
23
  self.conv3 = nn.Sequential(
24
+ nn.Conv1d(128, 1024, kernel_size=1), # βœ… Conv1d to match `[1024, 128, 1]`
25
+ nn.BatchNorm1d(1024),
26
  nn.ReLU()
27
  )
28
  self.fc = nn.Linear(1024, 512)
29
 
30
  def forward(self, x):
31
+ x = x.unsqueeze(-1) # βœ… Add extra dimension for Conv2d
32
  x = self.conv1(x)
33
  x = self.conv2(x)
34
+ x = x.squeeze(-1) # βœ… Remove extra dimension before Conv1d
35
+ x = self.conv3(x) # βœ… Now it correctly matches `[1024, 128, 1]`
36
+ return self.fc(x.max(dim=-1)[0]) # βœ… Fix pooling
37
+
38
+ class TransformNet(nn.Module):
39
+ """Transformation Network for PointCloud Encoding"""
40
+ def __init__(self, input_dim=6): # βœ… Ensure input has 6 channels
41
+ super(TransformNet, self).__init__()
42
+ self.conv1 = nn.Sequential(
43
+ nn.Conv2d(input_dim, 64, kernel_size=(1, 1)), # βœ… Conv2d (Matches checkpoint)
44
+ nn.BatchNorm2d(64),
45
+ nn.ReLU()
46
+ )
47
+ self.conv2 = nn.Sequential(
48
+ nn.Conv2d(64, 128, kernel_size=(1, 1)), # βœ… Conv2d (Matches checkpoint)
49
+ nn.BatchNorm2d(128),
50
+ nn.ReLU()
51
+ )
52
+ self.conv3 = nn.Sequential(
53
+ nn.Conv1d(128, 1024, kernel_size=1), # βœ… Conv1d to match `[1024, 128, 1]`
54
+ nn.BatchNorm1d(1024),
55
+ nn.ReLU()
56
+ )
57
+ self.fc = nn.Linear(1024, 512)
58
+
59
+ def forward(self, x):
60
+ x = x.unsqueeze(-1) # βœ… Add extra dimension for Conv2d
61
+ x = self.conv1(x)
62
+ x = self.conv2(x)
63
+ x = x.squeeze(-1) # βœ… Remove extra dimension before Conv1d
64
+ x = self.conv3(x) # βœ… Now it correctly matches `[1024, 128, 1]`
65
+ return self.fc(x.max(dim=-1)[0]) # βœ… Fix pooling
66
 
67
  class PointCloudEncoder(nn.Module):
68
+ """Point Cloud Encoder (pc_enc)"""
69
  def __init__(self):
70
  super(PointCloudEncoder, self).__init__()
71
  self.transform_net = TransformNet()
72
  self.convs = nn.ModuleList([
73
  nn.Sequential(
74
+ nn.Conv2d(512, 256, kernel_size=(1, 1)), # βœ… Conv2d (Matches checkpoint)
75
  nn.BatchNorm2d(256),
76
  nn.ReLU()
77
  ),
78
  nn.Sequential(
79
+ nn.Conv2d(256, 128, kernel_size=(1, 1)), # βœ… Conv2d (Matches checkpoint)
80
  nn.BatchNorm2d(128),
81
  nn.ReLU()
82
  ),
83
  nn.Sequential(
84
+ nn.Conv1d(128, 64, kernel_size=1), # βœ… Conv1d to match `[64, 128, 1]`
85
+ nn.BatchNorm1d(64),
86
  nn.ReLU()
87
  )
88
  ])
 
90
 
91
  def forward(self, x):
92
  x = self.transform_net(x)
93
+ for i, conv in enumerate(self.convs):
94
+ if i < 2:
95
+ x = conv(x) # βœ… Conv2d
96
+ else:
97
+ x = x.squeeze(-1) # βœ… Remove extra dimension before Conv1d
98
+ x = conv(x) # βœ… Conv1d (Matches `[64, 128, 1]`)
99
+ return self.lin_global(x.max(dim=-1)[0]) # βœ… Fix pooling
100
 
101
  class GarmentEncoder(nn.Module):
102
  """Garment Feature Encoder (garm_enc)"""
 
146
  return self.segm_dec(features)
147
 
148
  # Load Pretrained Model
149
+ model_path = "/content/closenet.pth"
150
  model = CloseNet()
151
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")), strict=False)
152
  model.eval()