gaur3009 committed on
Commit
efac922
·
verified ·
1 Parent(s): 221b341

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -28,12 +28,14 @@ class TransformNet(nn.Module):
28
  self.fc = nn.Linear(1024, 512)
29
 
30
  def forward(self, x):
31
- x = x.unsqueeze(-1) # ✅ Add extra dimension for Conv2d
 
 
32
  x = self.conv1(x)
33
  x = self.conv2(x)
34
- x = x.squeeze(-1) # ✅ Remove extra dimension before Conv1d
35
- x = self.conv3(x) # ✅ Now it correctly matches `[1024, 128, 1]`
36
- return self.fc(x.max(dim=-1)[0]) # ✅ Fix pooling
37
 
38
  class PointCloudEncoder(nn.Module):
39
  """Point Cloud Encoder (pc_enc)"""
@@ -125,11 +127,11 @@ model.eval()
125
  def segment_dress(image):
126
  """Detect and segment the dress from the image."""
127
  img = Image.fromarray(image).convert("RGB")
128
- img = np.array(img).transpose(2, 0, 1) # Convert to tensor format
129
- img = torch.tensor(img, dtype=torch.float32).unsqueeze(0) / 255.0 # Normalize
130
 
131
  with torch.no_grad():
132
- mask = model(img, clothing_classes=torch.arange(18))
133
  mask = mask.squeeze().numpy()
134
 
135
  mask = (mask > 0.5).astype(np.uint8) * 255 # Convert to binary mask
 
28
  self.fc = nn.Linear(1024, 512)
29
 
30
  def forward(self, x):
31
+ if x.dim() == 5:
32
+ x = x.squeeze(-1)
33
+
34
  x = self.conv1(x)
35
  x = self.conv2(x)
36
+ x = self.conv3(x)
37
+ return self.fc(x.max(dim=-1)[0]) # ✅ Ensure correct pooling
38
+
39
 
40
  class PointCloudEncoder(nn.Module):
41
  """Point Cloud Encoder (pc_enc)"""
 
127
  def segment_dress(image):
128
  """Detect and segment the dress from the image."""
129
  img = Image.fromarray(image).convert("RGB")
130
+ img = np.array(img).transpose(2, 0, 1) # Convert to [C, H, W]
131
+ img = torch.tensor(img, dtype=torch.float32).unsqueeze(0) / 255.0 # Normalize to [1, C, H, W]
132
 
133
  with torch.no_grad():
134
+ mask = model(img.squeeze(-1), clothing_classes=torch.arange(18)) # ✅ Remove extra dimension
135
  mask = mask.squeeze().numpy()
136
 
137
  mask = (mask > 0.5).astype(np.uint8) * 255 # Convert to binary mask