Spaces:
Running
on
Zero
Running
on
Zero
NIRVANALAN
commited on
Commit
·
1fa8ce9
1
Parent(s):
c3a2df4
update
Browse files- configs/i23d_args.json +1 -1
- nsr/train_util_diffusion.py +26 -9
configs/i23d_args.json
CHANGED
|
@@ -30,7 +30,7 @@
|
|
| 30 |
"log_interval": 50,
|
| 31 |
"eval_interval": 5000,
|
| 32 |
"save_interval": 10000,
|
| 33 |
-
"resume_checkpoint": "/
|
| 34 |
"resume_cldm_checkpoint": "",
|
| 35 |
"resume_checkpoint_EG3D": "",
|
| 36 |
"use_fp16": false,
|
|
|
|
| 30 |
"log_interval": 50,
|
| 31 |
"eval_interval": 5000,
|
| 32 |
"save_interval": 10000,
|
| 33 |
+
"resume_checkpoint": "checkpoints/objaverse/objaverse-dit/i23d/model_joint_denoise_rec_model2990000.safetensors",
|
| 34 |
"resume_cldm_checkpoint": "",
|
| 35 |
"resume_checkpoint_EG3D": "",
|
| 36 |
"use_fp16": false,
|
nsr/train_util_diffusion.py
CHANGED
|
@@ -32,6 +32,8 @@ from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
|
|
| 32 |
parse_resume_step_from_filename)
|
| 33 |
|
| 34 |
import dnnlib
|
|
|
|
|
|
|
| 35 |
|
| 36 |
from nsr.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
|
| 37 |
|
|
@@ -758,25 +760,40 @@ class TrainLoopDiffusionWithRec(TrainLoop):
|
|
| 758 |
model=None,
|
| 759 |
model_name='ddpm',
|
| 760 |
resume_checkpoint=None):
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
|
| 765 |
if model is None:
|
| 766 |
model = self.model
|
| 767 |
|
| 768 |
-
if resume_checkpoint and Path(resume_checkpoint).exists():
|
| 769 |
if dist_util.get_rank() == 0:
|
| 770 |
# ! rank 0 return will cause all other ranks to hang
|
| 771 |
-
logger.log(
|
| 772 |
-
f"loading model from checkpoint: {resume_checkpoint}...")
|
| 773 |
map_location = {
|
| 774 |
'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
|
| 775 |
} # configure map_location properly
|
| 776 |
-
|
| 777 |
logger.log(f'mark {model_name} loading ')
|
| 778 |
-
|
| 779 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
logger.log(f'mark {model_name} loading finished')
|
| 781 |
|
| 782 |
model_state_dict = model.state_dict()
|
|
|
|
| 32 |
parse_resume_step_from_filename)
|
| 33 |
|
| 34 |
import dnnlib
|
| 35 |
+
from safetensors.torch import load_file
|
| 36 |
+
from huggingface_hub import hf_hub_download
|
| 37 |
|
| 38 |
from nsr.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
|
| 39 |
|
|
|
|
| 760 |
model=None,
|
| 761 |
model_name='ddpm',
|
| 762 |
resume_checkpoint=None):
|
| 763 |
+
# load safetensors from hf
|
| 764 |
+
|
| 765 |
+
hf_loading = '.safetensors' in self.resume_checkpoint
|
| 766 |
+
if not hf_loading:
|
| 767 |
+
if resume_checkpoint is None:
|
| 768 |
+
resume_checkpoint, self.resume_step = find_resume_checkpoint(
|
| 769 |
+
self.resume_checkpoint, model_name) or self.resume_checkpoint
|
| 770 |
|
| 771 |
if model is None:
|
| 772 |
model = self.model
|
| 773 |
|
| 774 |
+
if hf_loading or (resume_checkpoint and Path(resume_checkpoint).exists()):
|
| 775 |
if dist_util.get_rank() == 0:
|
| 776 |
# ! rank 0 return will cause all other ranks to hang
|
|
|
|
|
|
|
| 777 |
map_location = {
|
| 778 |
'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
|
| 779 |
} # configure map_location properly
|
|
|
|
| 780 |
logger.log(f'mark {model_name} loading ')
|
| 781 |
+
|
| 782 |
+
if hf_loading:
|
| 783 |
+
logger.log(
|
| 784 |
+
f"loading model from huggingface: yslan/LN3Diff/{self.resume_checkpoint}...")
|
| 785 |
+
else:
|
| 786 |
+
logger.log(
|
| 787 |
+
f"loading model from checkpoint: {resume_checkpoint}...")
|
| 788 |
+
|
| 789 |
+
if hf_loading:
|
| 790 |
+
model_path = hf_hub_download(repo_id="yslan/LN3Diff",
|
| 791 |
+
filename=self.resume_checkpoint)
|
| 792 |
+
resume_state_dict = load_file(model_path)
|
| 793 |
+
else:
|
| 794 |
+
resume_state_dict = dist_util.load_state_dict(
|
| 795 |
+
resume_checkpoint, map_location=map_location)
|
| 796 |
+
|
| 797 |
logger.log(f'mark {model_name} loading finished')
|
| 798 |
|
| 799 |
model_state_dict = model.state_dict()
|