Spaces:
Sleeping
Sleeping
Commit
·
ddec8ca
1
Parent(s):
85a3747
Set map_location for pre-trained weights
Browse files
app.py
CHANGED
|
@@ -4,9 +4,6 @@ app.py
|
|
| 4 |
An interactive demo of text-guided shape generation.
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
import os
|
| 8 |
-
os.system("pip install -e ./custom_wheels/salad-0.1-py3-none-any.whl")
|
| 9 |
-
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Literal
|
| 12 |
|
|
@@ -32,7 +29,10 @@ def load_model(
|
|
| 32 |
checkpoint_dir = Path(__file__).parent / "checkpoints"
|
| 33 |
c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml")
|
| 34 |
model = hydra.utils.instantiate(c)
|
| 35 |
-
ckpt = torch.load(
|
|
|
|
|
|
|
|
|
|
| 36 |
model.load_state_dict(ckpt)
|
| 37 |
model.eval()
|
| 38 |
for p in model.parameters(): p.requires_grad_(False)
|
|
|
|
| 4 |
An interactive demo of text-guided shape generation.
|
| 5 |
"""
|
| 6 |
|
|
|
|
|
|
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Literal
|
| 9 |
|
|
|
|
| 29 |
checkpoint_dir = Path(__file__).parent / "checkpoints"
|
| 30 |
c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml")
|
| 31 |
model = hydra.utils.instantiate(c)
|
| 32 |
+
ckpt = torch.load(
|
| 33 |
+
checkpoint_dir / f"{model_class}/state_only.ckpt",
|
| 34 |
+
map_location=device,
|
| 35 |
+
)
|
| 36 |
model.load_state_dict(ckpt)
|
| 37 |
model.eval()
|
| 38 |
for p in model.parameters(): p.requires_grad_(False)
|