kaupane commited on
Commit
db88ac3
·
verified ·
1 Parent(s): a7df152

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -11
app.py CHANGED
@@ -23,13 +23,6 @@ USE_HALF_PRECISION = True
23
 
24
  def load_dit_model(dit_size):
25
  """Load DiT model of specified size"""
26
- #ckpt_path = f"./ckpts/DiT_{dit_size}_final.pth"
27
- ckpt_path = hf_hub_download(
28
- repo_id = "kaupane/DiT-Wikiart",
29
- filename = f"DiT_{dit_size}_final.pth"
30
- )
31
- if not os.path.exists(ckpt_path):
32
- raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")
33
 
34
  # Configure model based on size
35
  if dit_size == "S":
@@ -37,15 +30,13 @@ def load_dit_model(dit_size):
37
  model.from_pretrained("kaupane/DiT-Wikiart-Small")
38
  elif dit_size == "B":
39
  model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
 
40
  elif dit_size == "L":
41
  model = DiT(num_blocks=16, hidden_size=896, num_heads=14)
 
42
  else:
43
  raise ValueError(f"Invalid DiT size: {dit_size}")
44
 
45
- # Load checkpoint
46
- #checkpoint = torch.load(ckpt_path, map_location="cpu")
47
- #model.load_state_dict(checkpoint["model_state_dict"])
48
-
49
  return model
50
 
51
  class DiffusionSampler:
 
23
 
24
  def load_dit_model(dit_size):
25
  """Load DiT model of specified size"""
 
 
 
 
 
 
 
26
 
27
  # Configure model based on size
28
  if dit_size == "S":
 
30
  model.from_pretrained("kaupane/DiT-Wikiart-Small")
31
  elif dit_size == "B":
32
  model = DiT(num_blocks=12, hidden_size=640, num_heads=10)
33
+ model.from_pretrained("kaupane/DiT-Wikiart-Base")
34
  elif dit_size == "L":
35
  model = DiT(num_blocks=16, hidden_size=896, num_heads=14)
36
+ model.from_pretrained("kaupane/DiT-Wikiart-Large")
37
  else:
38
  raise ValueError(f"Invalid DiT size: {dit_size}")
39
 
 
 
 
 
40
  return model
41
 
42
  class DiffusionSampler: