JiantaoLin committed
Commit e945abc · 1 Parent(s): 4a35d51
Files changed (1)
  1. pipeline/kiss3d_wrapper.py +13 -13
pipeline/kiss3d_wrapper.py CHANGED
@@ -137,20 +137,20 @@ def init_wrapper_from_config(config_path):
     caption_model = None
 
     # load reconstruction model
-    # logger.info('==> Loading reconstruction model ...')
-    # recon_device = config_['reconstruction'].get('device', 'cpu')
-    # recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
-    # recon_model = instantiate_from_config(recon_model_config.model_config)
-    # # load recon model checkpoint
-    # model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
-    # state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
-    # state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
-    # recon_model.load_state_dict(state_dict, strict=True)
-    # recon_model.to(recon_device)
-    # recon_model.eval()
+    logger.info('==> Loading reconstruction model ...')
+    recon_device = config_['reconstruction'].get('device', 'cpu')
+    recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
+    recon_model = instantiate_from_config(recon_model_config.model_config)
+    # load recon model checkpoint
+    model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
+    state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+    state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
+    recon_model.load_state_dict(state_dict, strict=True)
+    recon_model.to(recon_device)
+    recon_model.eval()
     # logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
-    recon_model = None
-    recon_model_config = None
+    # recon_model = None
+    # recon_model_config = None
     # load llm
     llm_configs = config_.get('llm', None)
     if llm_configs is not None and False:
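One non-obvious detail in the re-enabled block is the checkpoint-key remapping: k[14:] drops the 14-character 'lrm_generator.' prefix so the stored weights match the parameter names of the bare reconstruction model. Below is a minimal standalone sketch of the same remapping, with the hard-coded 14 replaced by len(prefix); the prefix variable and the toy checkpoint dict are illustrative, not from the commit.

# Illustrative stand-in for torch.load(model_ckpt_path, map_location='cpu')['state_dict'].
checkpoint = {
    'lrm_generator.encoder.weight': 0,
    'lrm_generator.decoder.bias': 1,
    'loss_fn.scale': 2,            # filtered out: not part of the recon model
}

prefix = 'lrm_generator.'          # len(prefix) == 14, hence k[14:] in the diff
state_dict = {k[len(prefix):]: v
              for k, v in checkpoint.items()
              if k.startswith(prefix)}

assert state_dict == {'encoder.weight': 0, 'decoder.bias': 1}

Because the diff calls load_state_dict(state_dict, strict=True), loading fails loudly if any remapped key is missing or unexpected, so the startswith filter must keep exactly the lrm_generator.* entries and nothing else.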