gaur3009 commited on
Commit
ed2f309
Β·
verified Β·
1 Parent(s): b99093e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -29,14 +29,16 @@ class TransformNet(nn.Module):
29
 
30
  def forward(self, x):
31
  if x.dim() == 5:
32
- x = x.squeeze(-1)
33
-
34
  x = self.conv1(x)
35
  x = self.conv2(x)
36
- x = self.conv3(x)
 
 
 
37
  return self.fc(x.max(dim=-1)[0]) # βœ… Ensure correct pooling
38
 
39
-
40
  class PointCloudEncoder(nn.Module):
41
  """Point Cloud Encoder (pc_enc)"""
42
  def __init__(self):
@@ -61,16 +63,19 @@ class PointCloudEncoder(nn.Module):
61
  ])
62
  self.lin_global = nn.Linear(64, 128)
63
 
64
- def forward(self, x):
65
  x = self.transform_net(x)
 
66
  for i, conv in enumerate(self.convs):
67
  if i < 2:
68
- x = conv(x) # βœ… Conv2d
69
  else:
70
- x = x.squeeze(-1) # βœ… Remove extra dimension before Conv1d
71
- x = conv(x) # βœ… Conv1d (Matches `[64, 128, 1]`)
 
72
  return self.lin_global(x.max(dim=-1)[0]) # βœ… Fix pooling
73
 
 
74
  class GarmentEncoder(nn.Module):
75
  """Garment Feature Encoder (garm_enc)"""
76
  def __init__(self, num_classes=18, feature_dim=64):
 
29
 
30
  def forward(self, x):
31
  if x.dim() == 5:
32
+ x = x.squeeze(-1) # βœ… Remove extra dimension if exists
33
+
34
  x = self.conv1(x)
35
  x = self.conv2(x)
36
+
37
+ x = x.squeeze(-1) # βœ… Ensure shape is [batch, channels, length] before Conv1d
38
+ x = self.conv3(x) # βœ… Now Conv1d receives correct input shape [batch, channels, length]
39
+
40
  return self.fc(x.max(dim=-1)[0]) # βœ… Ensure correct pooling
41
 
 
42
  class PointCloudEncoder(nn.Module):
43
  """Point Cloud Encoder (pc_enc)"""
44
  def __init__(self):
 
63
  ])
64
  self.lin_global = nn.Linear(64, 128)
65
 
66
+ def forward(self, x):
67
  x = self.transform_net(x)
68
+
69
  for i, conv in enumerate(self.convs):
70
  if i < 2:
71
+ x = conv(x) # βœ… Conv2d keeps 4D
72
  else:
73
+ x = x.squeeze(-1) # βœ… Ensure shape is [batch, channels, length] before Conv1d
74
+ x = conv(x) # βœ… Conv1d now works with the correct input
75
+
76
  return self.lin_global(x.max(dim=-1)[0]) # βœ… Fix pooling
77
 
78
+
79
  class GarmentEncoder(nn.Module):
80
  """Garment Feature Encoder (garm_enc)"""
81
  def __init__(self, num_classes=18, feature_dim=64):