Update app.py
app.py CHANGED
@@ -28,12 +28,14 @@ class TransformNet(nn.Module):
         self.fc = nn.Linear(1024, 512)
 
     def forward(self, x):
-
+        if x.dim() == 5:
+            x = x.squeeze(-1)
+
         x = self.conv1(x)
         x = self.conv2(x)
-        x =
-
-
+        x = self.conv3(x)
+        return self.fc(x.max(dim=-1)[0])  # Ensure correct pooling
+
 
 class PointCloudEncoder(nn.Module):
     """Point Cloud Encoder (pc_enc)"""
@@ -125,11 +127,11 @@ model.eval()
 def segment_dress(image):
     """Detect and segment the dress from the image."""
     img = Image.fromarray(image).convert("RGB")
-    img = np.array(img).transpose(2, 0, 1)  # Convert to
-    img = torch.tensor(img, dtype=torch.float32).unsqueeze(0) / 255.0  # Normalize
+    img = np.array(img).transpose(2, 0, 1)  # Convert to [C, H, W]
+    img = torch.tensor(img, dtype=torch.float32).unsqueeze(0) / 255.0  # Normalize to [1, C, H, W]
 
     with torch.no_grad():
-        mask = model(img, clothing_classes=torch.arange(18))
+        mask = model(img.squeeze(-1), clothing_classes=torch.arange(18))  # Remove extra dimension
         mask = mask.squeeze().numpy()
 
     mask = (mask > 0.5).astype(np.uint8) * 255  # Convert to binary mask
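As a sanity check on the pooling change in forward, here is a minimal, self-contained sketch of the shape flow it implies. The Conv1d channel sizes (3 -> 64 -> 256 -> 1024) are assumptions chosen only so the output feeds nn.Linear(1024, 512) from the diff, not the real TransformNet layers, and the guard squeezes a trailing singleton dimension on a 4-D input for illustration, whereas the commit checks for a 5-D input.

import torch
import torch.nn as nn

# Assumed channel sizes; the real TransformNet layers may differ.
conv1 = nn.Conv1d(3, 64, kernel_size=1)
conv2 = nn.Conv1d(64, 256, kernel_size=1)
conv3 = nn.Conv1d(256, 1024, kernel_size=1)
fc = nn.Linear(1024, 512)

x = torch.randn(4, 3, 2048, 1)      # point cloud with a stray trailing dim: [B, C, N, 1]
if x.dim() == 4:                    # same idea as the dim() == 5 guard in the commit
    x = x.squeeze(-1)               # -> [B, C, N], the layout Conv1d expects
x = conv3(conv2(conv1(x)))          # -> [B, 1024, N]
x = x.max(dim=-1)[0]                # global max-pool over the N points -> [B, 1024]
out = fc(x)                         # -> [B, 512]
print(out.shape)                    # torch.Size([4, 512])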
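Similarly, a standalone sketch of the patched segment_dress preprocessing, with the tensor shape at each step spelled out. dummy_model is a hypothetical stand-in so the snippet runs on its own; the real model and its clothing_classes argument are defined elsewhere in app.py.

import numpy as np
import torch
from PIL import Image

# Hypothetical stand-in for the real segmentation model in app.py.
def dummy_model(x, clothing_classes=None):
    return torch.rand(x.shape[0], 1, x.shape[-2], x.shape[-1])  # fake per-pixel scores in [0, 1]

image = np.zeros((256, 256, 3), dtype=np.uint8)                      # H x W x 3 uint8, as segment_dress receives
img = Image.fromarray(image).convert("RGB")
img = np.array(img).transpose(2, 0, 1)                               # [H, W, C] -> [C, H, W]
img = torch.tensor(img, dtype=torch.float32).unsqueeze(0) / 255.0    # -> [1, C, H, W], scaled to [0, 1]

with torch.no_grad():
    mask = dummy_model(img.squeeze(-1), clothing_classes=torch.arange(18))  # matches the commit; squeeze(-1) only drops a trailing dim of size 1
    mask = mask.squeeze().numpy()                                     # -> [H, W] float scores

mask = (mask > 0.5).astype(np.uint8) * 255                            # binary mask: 0 or 255
print(mask.shape, mask.dtype)                                         # (256, 256) uint8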