quantumiracle commited on
Commit
ccfcf8d
·
1 Parent(s): a552667
Files changed (1) hide show
  1. SUPIR/util.py +23 -13
SUPIR/util.py CHANGED
@@ -8,6 +8,8 @@ from omegaconf import OmegaConf
8
  from sgm.util import instantiate_from_config
9
  from huggingface_hub import hf_hub_download
10
 
 
 
11
  def get_state_dict(d):
12
  return d.get('state_dict', d)
13
 
@@ -30,12 +32,14 @@ def create_model(config_path):
30
  print(f'Loaded model config from [{config_path}]')
31
  return model
32
 
 
33
  def resolve_ckpt_path(path_or_hub):
 
 
34
  if os.path.exists(path_or_hub):
35
- return path_or_hub # Local path
36
 
37
- # Handle Hugging Face Hub repo/file pattern
38
- if "/" in path_or_hub and path_or_hub.endswith((".ckpt", ".safetensors", ".bin")):
39
  parts = path_or_hub.split("/")
40
  repo_id = "/".join(parts[:2])
41
  filename = "/".join(parts[2:])
@@ -43,24 +47,31 @@ def resolve_ckpt_path(path_or_hub):
43
 
44
  raise FileNotFoundError(f"Could not resolve checkpoint path: {path_or_hub}")
45
 
 
 
 
 
 
 
46
  def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
47
  config = OmegaConf.load(config_path)
48
  model = instantiate_from_config(config.model).cpu()
49
  print(f'Loaded model config from [{config_path}]')
50
 
51
- if config.get("SDXL_CKPT") is not None:
52
  path = resolve_ckpt_path(config.SDXL_CKPT)
53
- model.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
54
 
55
- if config.get("SUPIR_CKPT") is not None:
56
  path = resolve_ckpt_path(config.SUPIR_CKPT)
57
- model.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
58
 
59
  if SUPIR_sign is not None:
60
- assert SUPIR_sign in ['F', 'Q']
61
  key = f"SUPIR_CKPT_{SUPIR_sign}"
62
- path = resolve_ckpt_path(config[key])
63
- model.load_state_dict(torch.load(path, map_location='cpu'), strict=False)
 
64
 
65
  if load_default_setting:
66
  return model, config.default_setting
@@ -68,11 +79,10 @@ def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False)
68
 
69
  def load_QF_ckpt(config_path):
70
  config = OmegaConf.load(config_path)
71
- ckpt_F = torch.load(resolve_ckpt_path(config.SUPIR_CKPT_F), map_location='cpu')
72
- ckpt_Q = torch.load(resolve_ckpt_path(config.SUPIR_CKPT_Q), map_location='cpu')
73
  return ckpt_Q, ckpt_F
74
 
75
-
76
  def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
77
  '''
78
  PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]
 
8
  from sgm.util import instantiate_from_config
9
  from huggingface_hub import hf_hub_download
10
 
11
+ from safetensors.torch import load_file as load_safetensors
12
+
13
  def get_state_dict(d):
14
  return d.get('state_dict', d)
15
 
 
32
  print(f'Loaded model config from [{config_path}]')
33
  return model
34
 
35
+
36
  def resolve_ckpt_path(path_or_hub):
37
+ path_or_hub = path_or_hub.strip()
38
+
39
  if os.path.exists(path_or_hub):
40
+ return path_or_hub
41
 
42
+ if "/" in path_or_hub and path_or_hub.endswith((".ckpt", ".pt", ".bin", ".safetensors")):
 
43
  parts = path_or_hub.split("/")
44
  repo_id = "/".join(parts[:2])
45
  filename = "/".join(parts[2:])
 
47
 
48
  raise FileNotFoundError(f"Could not resolve checkpoint path: {path_or_hub}")
49
 
50
+ def load_checkpoint(path):
51
+ if path.endswith(".safetensors"):
52
+ return load_safetensors(path, device='cpu')
53
+ else:
54
+ return torch.load(path, map_location='cpu')
55
+
56
  def create_SUPIR_model(config_path, SUPIR_sign=None, load_default_setting=False):
57
  config = OmegaConf.load(config_path)
58
  model = instantiate_from_config(config.model).cpu()
59
  print(f'Loaded model config from [{config_path}]')
60
 
61
+ if config.get("SDXL_CKPT"):
62
  path = resolve_ckpt_path(config.SDXL_CKPT)
63
+ model.load_state_dict(load_checkpoint(path), strict=False)
64
 
65
+ if config.get("SUPIR_CKPT"):
66
  path = resolve_ckpt_path(config.SUPIR_CKPT)
67
+ model.load_state_dict(load_checkpoint(path), strict=False)
68
 
69
  if SUPIR_sign is not None:
70
+ assert SUPIR_sign in ['F', 'Q'], "SUPIR_sign must be 'F' or 'Q'"
71
  key = f"SUPIR_CKPT_{SUPIR_sign}"
72
+ if config.get(key):
73
+ path = resolve_ckpt_path(config[key])
74
+ model.load_state_dict(load_checkpoint(path), strict=False)
75
 
76
  if load_default_setting:
77
  return model, config.default_setting
 
79
 
80
  def load_QF_ckpt(config_path):
81
  config = OmegaConf.load(config_path)
82
+ ckpt_F = load_checkpoint(resolve_ckpt_path(config.SUPIR_CKPT_F))
83
+ ckpt_Q = load_checkpoint(resolve_ckpt_path(config.SUPIR_CKPT_Q))
84
  return ckpt_Q, ckpt_F
85
 
 
86
  def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
87
  '''
88
  PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]