quantumiracle commited on
Commit
2848fdc
·
1 Parent(s): 2380c67
Files changed (1) hide show
  1. SUPIR/util.py +27 -13
SUPIR/util.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
6
  from torch.nn.functional import interpolate
7
  from omegaconf import OmegaConf
8
  from sgm.util import instantiate_from_config
9
-
10
 
11
  def get_state_dict(d):
12
  return d.get('state_dict', d)
@@ -30,30 +30,44 @@ def create_model(config_path):
30
  print(f'Loaded model config from [{config_path}]')
31
  return model
32
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
35
  config = OmegaConf.load(config_path)
36
  model = instantiate_from_config(config.model).cpu()
37
  print(f'Loaded model config from [{config_path}]')
38
- if config.SDXL_CKPT is not None:
39
- model.load_state_dict(load_state_dict(config.SDXL_CKPT), strict=False)
40
- if config.SUPIR_CKPT is not None:
41
- model.load_state_dict(load_state_dict(config.SUPIR_CKPT), strict=False)
 
 
 
 
 
42
  if SUPIR_sign is not None:
43
  assert SUPIR_sign in ['F', 'Q']
44
- if SUPIR_sign == 'F':
45
- model.load_state_dict(load_state_dict(config.SUPIR_CKPT_F), strict=False)
46
- elif SUPIR_sign == 'Q':
47
- model.load_state_dict(load_state_dict(config.SUPIR_CKPT_Q), strict=False)
48
  if load_default_setting:
49
- default_setting = config.default_setting
50
- return model, default_setting
51
  return model
52
 
53
  def load_QF_ckpt(config_path):
54
  config = OmegaConf.load(config_path)
55
- ckpt_F = torch.load(config.SUPIR_CKPT_F, map_location='cpu')
56
- ckpt_Q = torch.load(config.SUPIR_CKPT_Q, map_location='cpu')
57
  return ckpt_Q, ckpt_F
58
 
59
 
 
6
  from torch.nn.functional import interpolate
7
  from omegaconf import OmegaConf
8
  from sgm.util import instantiate_from_config
9
+ from huggingface_hub import hf_hub_download
10
 
11
  def get_state_dict(d):
12
  return d.get('state_dict', d)
 
30
  print(f'Loaded model config from [{config_path}]')
31
  return model
32
 
33
+ def resolve_ckpt_path(path_or_hub):
34
+ if os.path.exists(path_or_hub):
35
+ return path_or_hub # local path
36
+ if "/" in path_or_hub and path_or_hub.endswith(".ckpt"):
37
+ # Assume format: repo_id/path/to/file.ckpt
38
+ parts = path_or_hub.split("/")
39
+ repo_id = "/".join(parts[:2])
40
+ filename = "/".join(parts[2:])
41
+ return hf_hub_download(repo_id=repo_id, filename=filename)
42
+ return path_or_hub # fallback
43
 
44
  def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
45
  config = OmegaConf.load(config_path)
46
  model = instantiate_from_config(config.model).cpu()
47
  print(f'Loaded model config from [{config_path}]')
48
+
49
+ if config.get("SDXL_CKPT") is not None:
50
+ path = resolve_ckpt_path(config.SDXL_CKPT)
51
+ model.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
52
+
53
+ if config.get("SUPIR_CKPT") is not None:
54
+ path = resolve_ckpt_path(config.SUPIR_CKPT)
55
+ model.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
56
+
57
  if SUPIR_sign is not None:
58
  assert SUPIR_sign in ['F', 'Q']
59
+ key = f"SUPIR_CKPT_{SUPIR_sign}"
60
+ path = resolve_ckpt_path(config[key])
61
+ model.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
62
+
63
  if load_default_setting:
64
+ return model, config.default_setting
 
65
  return model
66
 
67
  def load_QF_ckpt(config_path):
68
  config = OmegaConf.load(config_path)
69
+ ckpt_F = torch.load(resolve_ckpt_path(config.SUPIR_CKPT_F), map_location='cpu')
70
+ ckpt_Q = torch.load(resolve_ckpt_path(config.SUPIR_CKPT_Q), map_location='cpu')
71
  return ckpt_Q, ckpt_F
72
 
73