Mikhael Johanes
commited on
Commit
·
75d2880
1
Parent(s):
d491737
fix deserialize weights in cpu only machine
Browse files
app.py
CHANGED
@@ -28,7 +28,7 @@ cfg = load_params(cfg_path)
|
|
28 |
def get_model(index):
|
29 |
TransformerPath = model_paths[index]
|
30 |
transformer = VQVAETransformer(cfg)
|
31 |
-
transformer.load_state_dict(torch.load(TransformerPath))
|
32 |
transformer = transformer.to(device)
|
33 |
transformer.eval()
|
34 |
return transformer
|
|
|
28 |
def get_model(index):
|
29 |
TransformerPath = model_paths[index]
|
30 |
transformer = VQVAETransformer(cfg)
|
31 |
+
transformer.load_state_dict(torch.load(TransformerPath, map_location=device))
|
32 |
transformer = transformer.to(device)
|
33 |
transformer.eval()
|
34 |
return transformer
|