Spaces:
Runtime error
Runtime error
Stop downloading models
Browse files
app.py
CHANGED
|
@@ -5,7 +5,6 @@ import gradio as gr
|
|
| 5 |
import numpy as np
|
| 6 |
import os
|
| 7 |
import torch
|
| 8 |
-
import subprocess
|
| 9 |
import output
|
| 10 |
|
| 11 |
from rdkit import Chem
|
|
@@ -53,24 +52,15 @@ args = parser.parse_args()
|
|
| 53 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 54 |
print(f'Device: {device}')
|
| 55 |
os.makedirs("results", exist_ok=True)
|
| 56 |
-
os.makedirs("models", exist_ok=True)
|
| 57 |
|
| 58 |
size_gnn_path = 'models/geom_size_gnn.ckpt'
|
| 59 |
-
if not os.path.exists(size_gnn_path):
|
| 60 |
-
print('Downloading SizeGNN model...')
|
| 61 |
-
link = 'https://zenodo.org/record/7121300/files/geom_size_gnn.ckpt?download=1'
|
| 62 |
-
subprocess.run(f'wget {link} -O {size_gnn_path}', shell=True)
|
| 63 |
size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
|
| 64 |
print('Loaded SizeGNN model')
|
| 65 |
|
| 66 |
|
| 67 |
diffusion_models = {}
|
| 68 |
for model_name, metadata in MODELS_METADATA.items():
|
| 69 |
-
link = metadata['link']
|
| 70 |
diffusion_path = metadata['path']
|
| 71 |
-
if not os.path.exists(diffusion_path):
|
| 72 |
-
print(f'Downloading {model_name}...')
|
| 73 |
-
subprocess.run(f'wget {link} -O {diffusion_path}', shell=True)
|
| 74 |
diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
|
| 75 |
print(f'Loaded model {model_name}')
|
| 76 |
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import os
|
| 7 |
import torch
|
|
|
|
| 8 |
import output
|
| 9 |
|
| 10 |
from rdkit import Chem
|
|
|
|
| 52 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 53 |
print(f'Device: {device}')
|
| 54 |
os.makedirs("results", exist_ok=True)
|
|
|
|
| 55 |
|
| 56 |
size_gnn_path = 'models/geom_size_gnn.ckpt'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
size_nn = SizeClassifier.load_from_checkpoint('models/geom_size_gnn.ckpt', map_location=device).eval().to(device)
|
| 58 |
print('Loaded SizeGNN model')
|
| 59 |
|
| 60 |
|
| 61 |
diffusion_models = {}
|
| 62 |
for model_name, metadata in MODELS_METADATA.items():
|
|
|
|
| 63 |
diffusion_path = metadata['path']
|
|
|
|
|
|
|
|
|
|
| 64 |
diffusion_models[model_name] = DDPM.load_from_checkpoint(diffusion_path, map_location=device).eval().to(device)
|
| 65 |
print(f'Loaded model {model_name}')
|
| 66 |
|