Update audiocraft/models/loaders.py
Browse files
audiocraft/models/loaders.py
CHANGED
@@ -117,6 +117,10 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='f
|
|
117 |
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
118 |
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
119 |
model = builders.get_lm_model(cfg)
|
|
|
|
|
|
|
|
|
120 |
model.load_state_dict(pkg['best_state'])
|
121 |
model.eval()
|
122 |
model.cfg = cfg
|
|
|
117 |
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
118 |
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
119 |
model = builders.get_lm_model(cfg)
|
120 |
+
if depth=='bfloat16':
|
121 |
+
model = model.to(torch.bfloat16)
|
122 |
+
if depth=='float16':
|
123 |
+
model = model.to(torch.float16)
|
124 |
model.load_state_dict(pkg['best_state'])
|
125 |
model.eval()
|
126 |
model.cfg = cfg
|