increased upload limit, moved uploader outside the form

Files changed:
- Dockerfile (+5 -9)
- app.py (+43 -30)
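Context for the second half of the commit: widgets placed inside `st.form` do not trigger reruns, and their values only reach the script when the submit button fires, so a `st.file_uploader` kept inside the form cannot cache its bytes ahead of submission. The stand-alone sketch below (illustrative names, not the app's actual code) shows the pattern the diff adopts in app.py: keep the uploader outside the form, stash the raw bytes in `st.session_state`, and read them back when the form is submitted.

import streamlit as st

# Uploader lives outside the form: its value is available on every rerun,
# so the raw bytes can be cached immediately.
uploaded = st.file_uploader("CSV with SMILES", type=["csv"])
if uploaded is not None:
    st.session_state["csv_bytes"] = uploaded.getvalue()

# The form only holds the text box and the submit button.
with st.form("demo_form"):
    text = st.text_area("...or paste SMILES")
    submitted = st.form_submit_button("Run")

if submitted:
    if text.strip():
        st.write("using pasted text")
    elif "csv_bytes" in st.session_state:
        st.write(f"using cached upload ({len(st.session_state['csv_bytes'])} bytes)")
    else:
        st.warning("No input provided")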
Dockerfile
CHANGED

@@ -1,39 +1,35 @@
-# Dockerfile: Streamlit/RDKit/PyG (Hugging Face Spaces)
 FROM python:3.10-slim
 
-#
+# OS libs for RDKit drawing
 RUN apt-get update && apt-get install -y --no-install-recommends \
     build-essential libxrender1 libxext6 libsm6 libx11-6 \
     libglib2.0-0 libfreetype6 libpng-dev wget && \
     rm -rf /var/lib/apt/lists/*
 
-#
+# Non-root user
 RUN useradd -m appuser
 
-#
+# Python deps
 RUN pip install --no-cache-dir --upgrade pip && \
     pip install --no-cache-dir \
     streamlit==1.45.0 rdkit-pypi==2022.9.5 pandas==2.2.3 \
     numpy==1.26.4 torch==2.2.0 torch-geometric==2.5.2 \
     ogb==1.3.6 pillow==10.3.0
 
-# Workdir and code
 WORKDIR /app
 COPY . .
 
-# Writable dirs
+# Writable dirs, owned by appuser, perms 775
 RUN install -d -o appuser -g appuser -m 775 /data /tmp/streamlit
 
-# Environment
 ENV DB_DIR=/data \
     STREAMLIT_SERVER_HEADLESS=true \
     STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
     STREAMLIT_SERVER_PORT=7860 \
     STREAMLIT_TELEMETRY_DISABLED=true \
     STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \
-    STREAMLIT_SERVER_MAX_UPLOAD_SIZE=…
+    STREAMLIT_SERVER_MAX_UPLOAD_SIZE=200
 
 EXPOSE 7860
-
 USER appuser
 CMD ["streamlit", "run", "app.py"]
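To sanity-check the image configuration locally, a small throwaway script along these lines could be run inside the container; `check_env.py` is a hypothetical helper, not part of the Space, and it only reads back the ENV block and the writable directories created above.

# check_env.py: hypothetical helper, not part of this repo. Run inside the
# container to confirm the ENV block and writable dirs from the Dockerfile.
import os
import pathlib

expected = {
    "DB_DIR": "/data",
    "STREAMLIT_SERVER_PORT": "7860",
    "STREAMLIT_SERVER_MAX_UPLOAD_SIZE": "200",
}
for key, want in expected.items():
    got = os.getenv(key)
    status = "ok" if got == want else f"MISMATCH (got {got!r})"
    print(f"{key}={got!r} -> {status}")

for d in ("/data", "/tmp/streamlit"):
    p = pathlib.Path(d)
    print(f"{d}: exists={p.exists()} writable={os.access(d, os.W_OK)}")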
app.py
CHANGED

@@ -12,16 +12,17 @@ from torch_geometric.loader import DataLoader
 from model import load_model
 from utils import smiles_to_data
 
-#
+# configuration
 DEVICE, RDKIT_DIM, MODEL_PATH, MAX_DISPLAY = "cpu", 6, "best_hybridgnn.pt", 20
 
-#
+# heavy imports already done above; now Streamlit starts
 @st.cache_resource
 def get_model():
     return load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
 
 model = get_model()
 
+# SQLite (cached) — DB stored in /data or /tmp
 DB_DIR = pathlib.Path(os.getenv("DB_DIR", "/tmp"))
 DB_DIR.mkdir(parents=True, exist_ok=True)
 
@@ -39,7 +40,7 @@ def init_db():
 conn = init_db()
 cursor = conn.cursor()
 
-#
+# compact info panel
 with st.sidebar.expander("Info & Env", expanded=False):
     st.write(f"Python {sys.version.split()[0]}")
     st.write(f"Temp dir: `{tempfile.gettempdir()}` "
@@ -47,7 +48,7 @@ with st.sidebar.expander("Info & Env", expanded=False):
 if "csv_bytes" in st.session_state:
     st.write(f"Last upload: **{len(st.session_state['csv_bytes'])/1024:.1f} KB**")
 
-#
+# header and instructions (unchanged)
 st.title("HOMO-LUMO Gap Predictor")
 st.markdown("""
 This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
@@ -60,52 +61,63 @@ This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph N
 - The app will display predictions and molecule images (up to 20 shown at once).
 """)
 
-#
+# uploader (outside the form)
 csv_file = st.file_uploader("CSV with SMILES", type=["csv"])
 if csv_file is not None:
-    st.session_state["csv_bytes"] = csv_file.getvalue()
+    st.session_state["csv_bytes"] = csv_file.getvalue()  # cache raw bytes
 
-#
+# textarea and button
 smiles_list = []
 with st.form("main_form"):
-    smiles_text = st.text_area(
-        …
-        …
+    smiles_text = st.text_area(
+        "…or paste SMILES (comma/newline separated)",
+        placeholder="CC(=O)Oc1ccccc1C(=O)O",
+        height=120,
+    )
     run = st.form_submit_button("Run Prediction")
 
-#
+# decide which input to use
 if run:
-    if …
+    if smiles_text.strip():  # user typed → override CSV
+        smiles_list = [
+            s.strip() for s in smiles_text.replace("\n", ",").split(",") if s.strip()
+        ]
+        st.session_state.pop("csv_bytes", None)  # forget previous upload
+        st.success(f"{len(smiles_list)} SMILES parsed from textbox")
+
+    elif "csv_bytes" in st.session_state:  # CSV path
         try:
-            df = pd.read_csv(
-                …
+            df = pd.read_csv(
+                StringIO(st.session_state["csv_bytes"].decode("utf-8")),
+                comment="#",
+            )
+            col = df.columns[0] if df.shape[1] == 1 else next(
+                (c for c in df.columns if c.lower() == "smiles"), None
+            )
             if col is None:
-                st.error("CSV needs one column or a 'SMILES' column")
+                st.error("CSV needs one column or a 'SMILES' column.")
             else:
                 smiles_list = df[col].dropna().astype(str).tolist()
                 st.success(f"{len(smiles_list)} SMILES loaded from CSV")
         except Exception as e:
             st.error(f"CSV error: {e}")
 
-    elif smiles_text.strip():
-        smiles_list = [s.strip() for s in smiles_text.replace("\n", ",").split(",") if s.strip()]
-        st.success(f"{len(smiles_list)} SMILES parsed from textbox")
     else:
-        st.warning("No input provided")
+        st.warning("No input provided.")
 
-#
+# inference & display
 if smiles_list:
     data_list = smiles_to_data(smiles_list, device=DEVICE)
     valid = [(s, d) for s, d in zip(smiles_list, data_list) if d is not None]
 
     if not valid:
-        st.warning("No valid molecules")
+        st.warning("No valid molecules.")
     else:
         vsmi, vdata = zip(*valid)
         preds = []
         for batch in DataLoader(vdata, batch_size=64):
             with torch.no_grad():
-                preds.extend(…
+                preds.extend(model(batch.to(DEVICE)).view(-1).cpu().numpy().tolist())
 
         st.subheader(f"Results (first {MAX_DISPLAY})")
         for i, (smi, pred) in enumerate(zip(vsmi, preds)):
@@ -115,17 +127,18 @@ if smiles_list:
             mol = Chem.MolFromSmiles(smi)
             if mol:
                 st.image(Draw.MolToImage(mol, size=(250, 250)))
-            st.write(f"`{smi}` → **{pred:.4f}…
+            st.write(f"`{smi}` → **{pred:.4f} eV**")
 
             cursor.execute(
                 "INSERT INTO predictions(smiles, prediction, timestamp) VALUES (?,?,?)",
-                (smi, float(pred), datetime.…
+                (smi, float(pred), datetime.utcnow().isoformat()),
             )
            conn.commit()
 
-        st.download_button(
-            …
-            …
-            …
-            …
-            …
+        st.download_button(
+            "Download CSV",
+            pd.DataFrame({"SMILES": vsmi, "Gap (eV)": [round(p, 4) for p in preds]})
+            .to_csv(index=False).encode(),
+            "homolumo_predictions.csv",
+            "text/csv",
+        )
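The app.py hunks call `init_db()` but never show its body. Going only by what the diff implies (a cached SQLite connection under `DB_DIR` and a `predictions` table with `smiles`, `prediction`, `timestamp` columns), a plausible sketch might look like the following; the table DDL and the `predictions.db` filename are assumptions, not the repository's actual code.

# Plausible sketch of init_db(); the real implementation is not shown in this
# diff. Column names and the DB filename are inferred/assumed from the INSERT
# statement and the DB_DIR handling above.
import os
import pathlib
import sqlite3
import streamlit as st

DB_DIR = pathlib.Path(os.getenv("DB_DIR", "/tmp"))

@st.cache_resource
def init_db():
    conn = sqlite3.connect(DB_DIR / "predictions.db", check_same_thread=False)
    conn.execute(
        """CREATE TABLE IF NOT EXISTS predictions (
               id INTEGER PRIMARY KEY AUTOINCREMENT,
               smiles TEXT,
               prediction REAL,
               timestamp TEXT
           )"""
    )
    conn.commit()
    return conn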