Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -27,16 +27,22 @@ def load_models(pretrained_model_path, device):
|
|
| 27 |
# Handle torch.hub checkpoint loading for CPU-only environments
|
| 28 |
map_location = torch.device("cpu") if device.type == "cpu" else None
|
| 29 |
|
| 30 |
-
# Load the UNet model and
|
| 31 |
unet = torch.hub.load(
|
| 32 |
repo_or_dir="aimagelab/multimodal-garment-designer",
|
| 33 |
source="github",
|
| 34 |
model="mgd",
|
| 35 |
pretrained=True,
|
| 36 |
dataset="dresscode", # Change to "vitonhd" if needed
|
| 37 |
-
map_location=map_location, # Ensure the model loads on CPU if needed
|
| 38 |
)
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
# Move UNet to the appropriate device
|
| 41 |
unet = unet.to(device)
|
| 42 |
|
|
|
|
| 27 |
# Handle torch.hub checkpoint loading for CPU-only environments
|
| 28 |
map_location = torch.device("cpu") if device.type == "cpu" else None
|
| 29 |
|
| 30 |
+
# Load the UNet model and force map_location for state_dict loading
|
| 31 |
unet = torch.hub.load(
|
| 32 |
repo_or_dir="aimagelab/multimodal-garment-designer",
|
| 33 |
source="github",
|
| 34 |
model="mgd",
|
| 35 |
pretrained=True,
|
| 36 |
dataset="dresscode", # Change to "vitonhd" if needed
|
|
|
|
| 37 |
)
|
| 38 |
|
| 39 |
+
# Ensure the model state dict is mapped correctly to the CPU if needed
|
| 40 |
+
if device.type == "cpu":
|
| 41 |
+
checkpoint_url = unet.config.get("checkpoint")
|
| 42 |
+
if checkpoint_url:
|
| 43 |
+
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
|
| 44 |
+
unet.load_state_dict(state_dict)
|
| 45 |
+
|
| 46 |
# Move UNet to the appropriate device
|
| 47 |
unet = unet.to(device)
|
| 48 |
|