JiantaoLin
commited on
Commit
·
e945abc
1
Parent(s):
4a35d51
new
Browse files- pipeline/kiss3d_wrapper.py +13 -13
pipeline/kiss3d_wrapper.py
CHANGED
@@ -137,20 +137,20 @@ def init_wrapper_from_config(config_path):
|
|
137 |
caption_model = None
|
138 |
|
139 |
# load reconstruction model
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
#
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
# logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
152 |
-
recon_model = None
|
153 |
-
recon_model_config = None
|
154 |
# load llm
|
155 |
llm_configs = config_.get('llm', None)
|
156 |
if llm_configs is not None and False:
|
|
|
137 |
caption_model = None
|
138 |
|
139 |
# load reconstruction model
|
140 |
+
logger.info('==> Loading reconstruction model ...')
|
141 |
+
recon_device = config_['reconstruction'].get('device', 'cpu')
|
142 |
+
recon_model_config = OmegaConf.load(config_['reconstruction']['model_config'])
|
143 |
+
recon_model = instantiate_from_config(recon_model_config.model_config)
|
144 |
+
# load recon model checkpoint
|
145 |
+
model_ckpt_path = hf_hub_download(repo_id="LTT/PRM", filename="final_ckpt.ckpt", repo_type="model")
|
146 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
147 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.')}
|
148 |
+
recon_model.load_state_dict(state_dict, strict=True)
|
149 |
+
recon_model.to(recon_device)
|
150 |
+
recon_model.eval()
|
151 |
# logger.warning(f"GPU memory allocated after load reconstruction model on {recon_device}: {torch.cuda.memory_allocated(device=recon_device) / 1024**3} GB")
|
152 |
+
# recon_model = None
|
153 |
+
# recon_model_config = None
|
154 |
# load llm
|
155 |
llm_configs = config_.get('llm', None)
|
156 |
if llm_configs is not None and False:
|