JiantaoLin commited on
Commit
9ba89f6
·
1 Parent(s): 4ffb78d
Files changed (1) hide show
  1. pipeline/kiss3d_wrapper.py +1 -1
pipeline/kiss3d_wrapper.py CHANGED
@@ -116,7 +116,7 @@ def init_wrapper_from_config(config_path):
116
  # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen_19w.ckpt", repo_type="model", token=access_token)
117
  unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
118
  if unet_ckpt_path is not None:
119
- state_dict = torch.load(unet_ckpt_path, map_location='cpu')['state_dict']
120
  # state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
121
  multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
122
 
 
116
  # unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen_19w.ckpt", repo_type="model", token=access_token)
117
  unet_ckpt_path = hf_hub_download(repo_id="LTT/Kiss3DGen", filename="flexgen.ckpt", repo_type="model", token=access_token)
118
  if unet_ckpt_path is not None:
119
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
120
  # state_dict = {k[10:]: v for k, v in state_dict.items() if k.startswith('unet.unet.')}
121
  multiview_pipeline.unet.load_state_dict(state_dict, strict=True)
122