Mikhael Johanes commited on
Commit
75d2880
·
1 Parent(s): d491737

fix deserialize weights in cpu only machine

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -28,7 +28,7 @@ cfg = load_params(cfg_path)
28
  def get_model(index):
29
  TransformerPath = model_paths[index]
30
  transformer = VQVAETransformer(cfg)
31
- transformer.load_state_dict(torch.load(TransformerPath))
32
  transformer = transformer.to(device)
33
  transformer.eval()
34
  return transformer
 
28
  def get_model(index):
29
  TransformerPath = model_paths[index]
30
  transformer = VQVAETransformer(cfg)
31
+ transformer.load_state_dict(torch.load(TransformerPath, map_location=device))
32
  transformer = transformer.to(device)
33
  transformer.eval()
34
  return transformer