Update OmniGen/model.py
Browse files — OmniGen/model.py: +7 −2
OmniGen/model.py
CHANGED
|
@@ -9,6 +9,7 @@ from typing import Dict
|
|
| 9 |
from diffusers.loaders import PeftAdapterMixin
|
| 10 |
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 11 |
from huggingface_hub import snapshot_download
|
|
|
|
| 12 |
|
| 13 |
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
| 14 |
|
|
@@ -187,14 +188,18 @@ class OmniGen(nn.Module, PeftAdapterMixin):
|
|
| 187 |
|
| 188 |
@classmethod
|
| 189 |
def from_pretrained(cls, model_name):
|
| 190 |
-
if not os.path.exists(
|
| 191 |
cache_folder = os.getenv('HF_HUB_CACHE')
|
| 192 |
model_name = snapshot_download(repo_id=model_name,
|
| 193 |
cache_dir=cache_folder,
|
| 194 |
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
| 195 |
config = Phi3Config.from_pretrained(model_name)
|
| 196 |
model = cls(config)
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
model.load_state_dict(ckpt)
|
| 199 |
return model
|
| 200 |
|
|
|
|
| 9 |
from diffusers.loaders import PeftAdapterMixin
|
| 10 |
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 11 |
from huggingface_hub import snapshot_download
|
| 12 |
+
from safetensors.torch import load_file
|
| 13 |
|
| 14 |
from OmniGen.transformer import Phi3Config, Phi3Transformer
|
| 15 |
|
|
|
|
| 188 |
|
| 189 |
@classmethod
|
| 190 |
def from_pretrained(cls, model_name):
|
| 191 |
+
if not os.path.exists(model_name):
|
| 192 |
cache_folder = os.getenv('HF_HUB_CACHE')
|
| 193 |
model_name = snapshot_download(repo_id=model_name,
|
| 194 |
cache_dir=cache_folder,
|
| 195 |
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
| 196 |
config = Phi3Config.from_pretrained(model_name)
|
| 197 |
model = cls(config)
|
| 198 |
+
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
| 199 |
+
print("Loading safetensors")
|
| 200 |
+
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
| 201 |
+
else:
|
| 202 |
+
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
| 203 |
model.load_state_dict(ckpt)
|
| 204 |
return model
|
| 205 |
|