Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -23,13 +23,6 @@ USE_HALF_PRECISION = True
|
|
23 |
|
24 |
def load_dit_model(dit_size):
|
25 |
"""Load DiT model of specified size"""
|
26 |
-
#ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
|
27 |
-
ckpt_path = hf_hub_download(
|
28 |
-
repo_id = "kaupane/DiT-Wikiart",
|
29 |
-
filename = f"DiT_{dit_size}_final.pth"
|
30 |
-
)
|
31 |
-
if not os.path.exists(ckpt_path):
|
32 |
-
raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")
|
33 |
|
34 |
# Configure model based on size
|
35 |
if dit_size == "S":
|
@@ -37,15 +30,13 @@ def load_dit_model(dit_size):
|
|
37 |
model.from_pretrained("kaupane/DiT-Wikiart-Small")
|
38 |
elif dit_size == "B":
|
39 |
model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
|
|
|
40 |
elif dit_size == "L":
|
41 |
model = DiT(num_blocks=16, hidden_size=896, num_heads=14)
|
|
|
42 |
else:
|
43 |
raise ValueError(f"Invalid DiT size: {dit_size}")
|
44 |
|
45 |
-
# Load checkpoint
|
46 |
-
#checkpoint = torch.load(ckpt_path, map_location="cpu")
|
47 |
-
#model.load_state_dict(checkpoint["model_state_dict"])
|
48 |
-
|
49 |
return model
|
50 |
|
51 |
class DiffusionSampler:
|
|
|
23 |
|
24 |
def load_dit_model(dit_size):
|
25 |
"""Load DiT model of specified size"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# Configure model based on size
|
28 |
if dit_size == "S":
|
|
|
30 |
model.from_pretrained("kaupane/DiT-Wikiart-Small")
|
31 |
elif dit_size == "B":
|
32 |
model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
|
33 |
+
model.from_pretrained("kaupane/DiT-Wikiart-Base")
|
34 |
elif dit_size == "L":
|
35 |
model = DiT(num_blocks=16, hidden_size=896, num_heads=14)
|
36 |
+
model.from_pretrained("kaupane/DiT-Wikiart-Large")
|
37 |
else:
|
38 |
raise ValueError(f"Invalid DiT size: {dit_size}")
|
39 |
|
|
|
|
|
|
|
|
|
40 |
return model
|
41 |
|
42 |
class DiffusionSampler:
|