ford442 commited on
Commit
e81ccc6
·
verified ·
1 Parent(s): 80272ee

Update audiocraft/models/loaders.py

Browse files
Files changed (1) hide show
  1. audiocraft/models/loaders.py +4 -0
audiocraft/models/loaders.py CHANGED
@@ -117,6 +117,10 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', depth='f
117
  _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
118
  _delete_param(cfg, 'conditioners.args.drop_desc_p')
119
  model = builders.get_lm_model(cfg)
 
 
 
 
120
  model.load_state_dict(pkg['best_state'])
121
  model.eval()
122
  model.cfg = cfg
 
117
  _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
118
  _delete_param(cfg, 'conditioners.args.drop_desc_p')
119
  model = builders.get_lm_model(cfg)
120
+ if depth=='bfloat16':
121
+ model = model.to(torch.bfloat16)
122
+ if depth=='float16':
123
+ model = model.to(torch.float16)
124
  model.load_state_dict(pkg['best_state'])
125
  model.eval()
126
  model.cfg = cfg