katrihiovain commited on
Commit
2b63853
·
1 Parent(s): 073f6eb

pudated models.py line 240

Browse files
Files changed (1) hide show
  1. models.py +2 -1
models.py CHANGED
@@ -236,7 +236,8 @@ def load_model_from_ckpt(checkpoint_data, model, key='state_dict'):
236
  def load_and_setup_model(model_name, parser, checkpoint, amp, device,
237
  unk_args=[], forward_is_infer=False, jitable=False):
238
  if checkpoint is not None:
239
- ckpt_data = torch.load(checkpoint)
 
240
  print(f'{model_name}: Loading {checkpoint}...')
241
  ckpt_config = ckpt_data.get('config')
242
  if ckpt_config is None:
 
236
  def load_and_setup_model(model_name, parser, checkpoint, amp, device,
237
  unk_args=[], forward_is_infer=False, jitable=False):
238
  if checkpoint is not None:
239
+ #ckpt_data = torch.load(checkpoint)
240
+ ckpt_data = torch.load(checkpoint, map_location=device)
241
  print(f'{model_name}: Loading {checkpoint}...')
242
  ckpt_config = ckpt_data.get('config')
243
  if ckpt_config is None: