MooseML commited on
Commit
0714060
·
1 Parent(s): b8f152e

simplifying the file reading logic by passing the UploadedFile

Browse files
Files changed (1) hide show
  1. app.py +18 -25
app.py CHANGED
@@ -6,22 +6,21 @@ from datetime import datetime
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
  import os, pathlib
9
- from io import StringIO
10
  from model import load_model
11
  from utils import smiles_to_data
12
  from torch_geometric.loader import DataLoader
13
 
14
- # Config
15
  DEVICE = "cpu"
16
  RDKIT_DIM = 6
17
  MODEL_PATH = "best_hybridgnn.pt"
18
  MAX_DISPLAY = 10
19
 
20
- # Load Model
21
  model = load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
22
 
23
- # SQLite Setup
24
- DB_DIR = os.getenv("DB_DIR", "/tmp") # /data if you add a volume later
25
  pathlib.Path(DB_DIR).mkdir(parents=True, exist_ok=True)
26
 
27
  @st.cache_resource
@@ -43,7 +42,7 @@ def init_db():
43
  conn = init_db()
44
  cursor = conn.cursor()
45
 
46
- # Streamlit UI
47
  st.title("HOMO-LUMO Gap Predictor")
48
  st.markdown("""
49
  This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
@@ -56,10 +55,6 @@ This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph N
56
  - The app will display predictions and molecule images (up to 10 shown at once).
57
  """)
58
 
59
-
60
-
61
-
62
- # Single input form
63
  smiles_list = []
64
 
65
  with st.form("smiles_or_csv"):
@@ -78,9 +73,8 @@ if run:
78
  if csv_file is not None:
79
  try:
80
  csv_file.seek(0)
81
- df = pd.read_csv(StringIO(csv_file.getvalue().decode("utf‑8")), comment="#")
82
 
83
- # pick SMILES column
84
  if df.shape[1] == 1:
85
  smiles_col = df.iloc[:, 0]
86
  elif "smiles" in [c.lower() for c in df.columns]:
@@ -105,13 +99,11 @@ if run:
105
  else:
106
  st.warning("Please paste SMILES or upload a CSV before pressing *Run*.")
107
 
108
-
109
- # Run Inference
110
  if smiles_list:
111
  with st.spinner("Processing molecules..."):
112
  data_list = smiles_to_data(smiles_list, device=DEVICE)
113
 
114
- # Filter only valid molecules and keep aligned SMILES
115
  valid_pairs = [(smi, data) for smi, data in zip(smiles_list, data_list) if data is not None]
116
 
117
  if not valid_pairs:
@@ -127,7 +119,6 @@ if smiles_list:
127
  pred = model(batch).view(-1).cpu().numpy()
128
  predictions.extend(pred.tolist())
129
 
130
- # Display Results
131
  st.subheader(f"Predictions (showing up to {MAX_DISPLAY} molecules):")
132
 
133
  for i, (smi, pred) in enumerate(zip(valid_smiles, predictions)):
@@ -141,16 +132,18 @@ if smiles_list:
141
  st.write(f"**SMILES**: `{smi}`")
142
  st.write(f"**Predicted HOMO-LUMO Gap**: `{pred:.4f} eV`")
143
 
144
- # Log to SQLite
145
  cursor.execute("INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
146
  (smi, pred, str(datetime.now())))
147
  conn.commit()
148
 
149
- # Download Results
150
- result_df = pd.DataFrame({"SMILES": valid_smiles,
151
- "Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in predictions]})
152
-
153
- st.download_button(label="Download Predictions as CSV",
154
- data=result_df.to_csv(index=False).encode('utf-8'),
155
- file_name="homolumo_predictions.csv",
156
- mime="text/csv")
 
 
 
 
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
  import os, pathlib
 
9
  from model import load_model
10
  from utils import smiles_to_data
11
  from torch_geometric.loader import DataLoader
12
 
13
+ # Config
14
  DEVICE = "cpu"
15
  RDKIT_DIM = 6
16
  MODEL_PATH = "best_hybridgnn.pt"
17
  MAX_DISPLAY = 10
18
 
19
+ # Load Model
20
  model = load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
21
 
22
+ # SQLite Setup
23
+ DB_DIR = os.getenv("DB_DIR", "/tmp")
24
  pathlib.Path(DB_DIR).mkdir(parents=True, exist_ok=True)
25
 
26
  @st.cache_resource
 
42
  conn = init_db()
43
  cursor = conn.cursor()
44
 
45
+ # Streamlit UI
46
  st.title("HOMO-LUMO Gap Predictor")
47
  st.markdown("""
48
  This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
 
55
  - The app will display predictions and molecule images (up to 10 shown at once).
56
  """)
57
 
 
 
 
 
58
  smiles_list = []
59
 
60
  with st.form("smiles_or_csv"):
 
73
  if csv_file is not None:
74
  try:
75
  csv_file.seek(0)
76
+ df = pd.read_csv(csv_file, comment="#")
77
 
 
78
  if df.shape[1] == 1:
79
  smiles_col = df.iloc[:, 0]
80
  elif "smiles" in [c.lower() for c in df.columns]:
 
99
  else:
100
  st.warning("Please paste SMILES or upload a CSV before pressing *Run*.")
101
 
102
+ # Run Inference
 
103
  if smiles_list:
104
  with st.spinner("Processing molecules..."):
105
  data_list = smiles_to_data(smiles_list, device=DEVICE)
106
 
 
107
  valid_pairs = [(smi, data) for smi, data in zip(smiles_list, data_list) if data is not None]
108
 
109
  if not valid_pairs:
 
119
  pred = model(batch).view(-1).cpu().numpy()
120
  predictions.extend(pred.tolist())
121
 
 
122
  st.subheader(f"Predictions (showing up to {MAX_DISPLAY} molecules):")
123
 
124
  for i, (smi, pred) in enumerate(zip(valid_smiles, predictions)):
 
132
  st.write(f"**SMILES**: `{smi}`")
133
  st.write(f"**Predicted HOMO-LUMO Gap**: `{pred:.4f} eV`")
134
 
 
135
  cursor.execute("INSERT INTO predictions (smiles, prediction, timestamp) VALUES (?, ?, ?)",
136
  (smi, pred, str(datetime.now())))
137
  conn.commit()
138
 
139
+ result_df = pd.DataFrame({
140
+ "SMILES": valid_smiles,
141
+ "Predicted HOMO-LUMO Gap (eV)": [round(p, 4) for p in predictions]
142
+ })
143
+
144
+ st.download_button(
145
+ label="Download Predictions as CSV",
146
+ data=result_df.to_csv(index=False).encode('utf-8'),
147
+ file_name="homolumo_predictions.csv",
148
+ mime="text/csv"
149
+ )