MooseML committed on
Commit
d072ab4
·
1 Parent(s): 5331c58

increased upload limit, moved uploader outside the form

Browse files
Files changed (2) hide show
  1. Dockerfile +5 -9
  2. app.py +43 -30
Dockerfile CHANGED
@@ -1,39 +1,35 @@
1
- # Dockerfile: Streamlit/RDKit/PyG (Hugging Face Spaces)
2
  FROM python:3.10-slim
3
 
4
- # OS libs for RDKit drawing
5
  RUN apt-get update && apt-get install -y --no-install-recommends \
6
  build-essential libxrender1 libxext6 libsm6 libx11-6 \
7
  libglib2.0-0 libfreetype6 libpng-dev wget && \
8
  rm -rf /var/lib/apt/lists/*
9
 
10
- # Non‑root user
11
  RUN useradd -m appuser
12
 
13
- # Python packages
14
  RUN pip install --no-cache-dir --upgrade pip && \
15
  pip install --no-cache-dir \
16
  streamlit==1.45.0 rdkit-pypi==2022.9.5 pandas==2.2.3 \
17
  numpy==1.26.4 torch==2.2.0 torch-geometric==2.5.2 \
18
  ogb==1.3.6 pillow==10.3.0
19
 
20
- # Workdir and code
21
  WORKDIR /app
22
  COPY . .
23
 
24
- # Writable dirs with 775 perms
25
  RUN install -d -o appuser -g appuser -m 775 /data /tmp/streamlit
26
 
27
- # Environment
28
  ENV DB_DIR=/data \
29
  STREAMLIT_SERVER_HEADLESS=true \
30
  STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
31
  STREAMLIT_SERVER_PORT=7860 \
32
  STREAMLIT_TELEMETRY_DISABLED=true \
33
  STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \
34
- STREAMLIT_SERVER_MAX_UPLOAD_SIZE=50
35
 
36
  EXPOSE 7860
37
-
38
  USER appuser
39
  CMD ["streamlit", "run", "app.py"]
 
 
1
  FROM python:3.10-slim
2
 
3
+ # OS libs for RDKit drawing
4
  RUN apt-get update && apt-get install -y --no-install-recommends \
5
  build-essential libxrender1 libxext6 libsm6 libx11-6 \
6
  libglib2.0-0 libfreetype6 libpng-dev wget && \
7
  rm -rf /var/lib/apt/lists/*
8
 
9
+ # Non‑root user
10
  RUN useradd -m appuser
11
 
12
+ # Python deps
13
  RUN pip install --no-cache-dir --upgrade pip && \
14
  pip install --no-cache-dir \
15
  streamlit==1.45.0 rdkit-pypi==2022.9.5 pandas==2.2.3 \
16
  numpy==1.26.4 torch==2.2.0 torch-geometric==2.5.2 \
17
  ogb==1.3.6 pillow==10.3.0
18
 
 
19
  WORKDIR /app
20
  COPY . .
21
 
22
+ # Writable dirs, owned by appuser, perms 775
23
  RUN install -d -o appuser -g appuser -m 775 /data /tmp/streamlit
24
 
 
25
  ENV DB_DIR=/data \
26
  STREAMLIT_SERVER_HEADLESS=true \
27
  STREAMLIT_SERVER_ADDRESS=0.0.0.0 \
28
  STREAMLIT_SERVER_PORT=7860 \
29
  STREAMLIT_TELEMETRY_DISABLED=true \
30
  STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \
31
+ STREAMLIT_SERVER_MAX_UPLOAD_SIZE=200
32
 
33
  EXPOSE 7860
 
34
  USER appuser
35
  CMD ["streamlit", "run", "app.py"]
app.py CHANGED
@@ -12,16 +12,17 @@ from torch_geometric.loader import DataLoader
12
  from model import load_model
13
  from utils import smiles_to_data
14
 
15
- # Config
16
  DEVICE, RDKIT_DIM, MODEL_PATH, MAX_DISPLAY = "cpu", 6, "best_hybridgnn.pt", 20
17
 
18
- # Model & DB (cached)
19
  @st.cache_resource
20
  def get_model():
21
  return load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
22
 
23
  model = get_model()
24
 
 
25
  DB_DIR = pathlib.Path(os.getenv("DB_DIR", "/tmp"))
26
  DB_DIR.mkdir(parents=True, exist_ok=True)
27
 
@@ -39,7 +40,7 @@ def init_db():
39
  conn = init_db()
40
  cursor = conn.cursor()
41
 
42
- # debug and info panel
43
  with st.sidebar.expander("Info & Env", expanded=False):
44
  st.write(f"Python {sys.version.split()[0]}")
45
  st.write(f"Temp dir: `{tempfile.gettempdir()}` "
@@ -47,7 +48,7 @@ with st.sidebar.expander("Info & Env", expanded=False):
47
  if "csv_bytes" in st.session_state:
48
  st.write(f"Last upload: **{len(st.session_state['csv_bytes'])/1024:.1f} KB**")
49
 
50
- # Header
51
  st.title("HOMO-LUMO Gap Predictor")
52
  st.markdown("""
53
  This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
@@ -60,52 +61,63 @@ This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph N
60
  - The app will display predictions and molecule images (up to 20 shown at once).
61
  """)
62
 
63
- # File uploader (outside form)
64
  csv_file = st.file_uploader("CSV with SMILES", type=["csv"])
65
  if csv_file is not None:
66
- st.session_state["csv_bytes"] = csv_file.getvalue()
67
 
68
- # Input form
69
  smiles_list = []
70
  with st.form("main_form"):
71
- smiles_text = st.text_area("…or paste SMILES (comma/newline separated)",
72
- placeholder="CC(=O)Oc1ccccc1C(=O)O",
73
- height=120)
 
 
74
  run = st.form_submit_button("Run Prediction")
75
 
76
- # Parse input
77
  if run:
78
- if "csv_bytes" in st.session_state:
 
 
 
 
 
 
 
79
  try:
80
- df = pd.read_csv(StringIO(st.session_state["csv_bytes"].decode("utf-8")), comment="#")
81
- col = df.columns[0] if df.shape[1] == 1 else next((c for c in df.columns if c.lower() == "smiles"), None)
 
 
 
 
 
82
  if col is None:
83
- st.error("CSV needs one column or a 'SMILES' column")
84
  else:
85
  smiles_list = df[col].dropna().astype(str).tolist()
86
  st.success(f"{len(smiles_list)} SMILES loaded from CSV")
87
  except Exception as e:
88
  st.error(f"CSV error: {e}")
89
 
90
- elif smiles_text.strip():
91
- smiles_list = [s.strip() for s in smiles_text.replace("\n", ",").split(",") if s.strip()]
92
- st.success(f"{len(smiles_list)} SMILES parsed from textbox")
93
  else:
94
- st.warning("No input provided")
95
 
96
- # Inference & display
97
  if smiles_list:
98
  data_list = smiles_to_data(smiles_list, device=DEVICE)
99
  valid = [(s, d) for s, d in zip(smiles_list, data_list) if d is not None]
100
 
101
  if not valid:
102
- st.warning("No valid molecules")
103
  else:
104
  vsmi, vdata = zip(*valid)
105
  preds = []
106
  for batch in DataLoader(vdata, batch_size=64):
107
  with torch.no_grad():
108
- preds.extend(get_model()(batch.to(DEVICE)).view(-1).cpu().numpy().tolist())
109
 
110
  st.subheader(f"Results (first {MAX_DISPLAY})")
111
  for i, (smi, pred) in enumerate(zip(vsmi, preds)):
@@ -115,17 +127,18 @@ if smiles_list:
115
  mol = Chem.MolFromSmiles(smi)
116
  if mol:
117
  st.image(Draw.MolToImage(mol, size=(250, 250)))
118
- st.write(f"`{smi}` → **{pred:.4f} eV**")
119
 
120
  cursor.execute(
121
  "INSERT INTO predictions(smiles, prediction, timestamp) VALUES (?,?,?)",
122
- (smi, float(pred), datetime.now().isoformat()),
123
  )
124
  conn.commit()
125
 
126
- st.download_button("Download CSV",
127
- pd.DataFrame(
128
- {"SMILES": vsmi, "Gap (eV)": [round(p, 4) for p in preds]}
129
- ).to_csv(index=False).encode(),
130
- "homolumo_predictions.csv",
131
- "text/csv")
 
 
12
  from model import load_model
13
  from utils import smiles_to_data
14
 
15
+ # configuration
16
  DEVICE, RDKIT_DIM, MODEL_PATH, MAX_DISPLAY = "cpu", 6, "best_hybridgnn.pt", 20
17
 
18
+ # heavy imports already done above; now Streamlit starts
19
  @st.cache_resource
20
  def get_model():
21
  return load_model(rdkit_dim=RDKIT_DIM, path=MODEL_PATH, device=DEVICE)
22
 
23
  model = get_model()
24
 
25
+ # SQLite (cached) — DB stored in /data or /tmp
26
  DB_DIR = pathlib.Path(os.getenv("DB_DIR", "/tmp"))
27
  DB_DIR.mkdir(parents=True, exist_ok=True)
28
 
 
40
  conn = init_db()
41
  cursor = conn.cursor()
42
 
43
+ # compact info panel
44
  with st.sidebar.expander("Info & Env", expanded=False):
45
  st.write(f"Python {sys.version.split()[0]}")
46
  st.write(f"Temp dir: `{tempfile.gettempdir()}` "
 
48
  if "csv_bytes" in st.session_state:
49
  st.write(f"Last upload: **{len(st.session_state['csv_bytes'])/1024:.1f} KB**")
50
 
51
+ # header and instructions (unchanged)
52
  st.title("HOMO-LUMO Gap Predictor")
53
  st.markdown("""
54
  This app predicts the HOMO-LUMO energy gap for molecules using a trained Graph Neural Network (GNN).
 
61
  - The app will display predictions and molecule images (up to 20 shown at once).
62
  """)
63
 
64
+ # uploader (outside the form)
65
  csv_file = st.file_uploader("CSV with SMILES", type=["csv"])
66
  if csv_file is not None:
67
+ st.session_state["csv_bytes"] = csv_file.getvalue() # cache raw bytes
68
 
69
+ # textarea and button
70
  smiles_list = []
71
  with st.form("main_form"):
72
+ smiles_text = st.text_area(
73
+ "…or paste SMILES (comma/newline separated)",
74
+ placeholder="CC(=O)Oc1ccccc1C(=O)O",
75
+ height=120,
76
+ )
77
  run = st.form_submit_button("Run Prediction")
78
 
79
+ # decide which input to use
80
  if run:
81
+ if smiles_text.strip(): # user typed → override CSV
82
+ smiles_list = [
83
+ s.strip() for s in smiles_text.replace("\n", ",").split(",") if s.strip()
84
+ ]
85
+ st.session_state.pop("csv_bytes", None) # forget previous upload
86
+ st.success(f"{len(smiles_list)} SMILES parsed from textbox")
87
+
88
+ elif "csv_bytes" in st.session_state: # CSV path
89
  try:
90
+ df = pd.read_csv(
91
+ StringIO(st.session_state["csv_bytes"].decode("utf-8")),
92
+ comment="#",
93
+ )
94
+ col = df.columns[0] if df.shape[1] == 1 else next(
95
+ (c for c in df.columns if c.lower() == "smiles"), None
96
+ )
97
  if col is None:
98
+ st.error("CSV needs one column or a 'SMILES' column.")
99
  else:
100
  smiles_list = df[col].dropna().astype(str).tolist()
101
  st.success(f"{len(smiles_list)} SMILES loaded from CSV")
102
  except Exception as e:
103
  st.error(f"CSV error: {e}")
104
 
 
 
 
105
  else:
106
+ st.warning("No input provided.")
107
 
108
+ # inference & display
109
  if smiles_list:
110
  data_list = smiles_to_data(smiles_list, device=DEVICE)
111
  valid = [(s, d) for s, d in zip(smiles_list, data_list) if d is not None]
112
 
113
  if not valid:
114
+ st.warning("No valid molecules.")
115
  else:
116
  vsmi, vdata = zip(*valid)
117
  preds = []
118
  for batch in DataLoader(vdata, batch_size=64):
119
  with torch.no_grad():
120
+ preds.extend(model(batch.to(DEVICE)).view(-1).cpu().numpy().tolist())
121
 
122
  st.subheader(f"Results (first {MAX_DISPLAY})")
123
  for i, (smi, pred) in enumerate(zip(vsmi, preds)):
 
127
  mol = Chem.MolFromSmiles(smi)
128
  if mol:
129
  st.image(Draw.MolToImage(mol, size=(250, 250)))
130
+ st.write(f"`{smi}` → **{pred:.4f}eV**")
131
 
132
  cursor.execute(
133
  "INSERT INTO predictions(smiles, prediction, timestamp) VALUES (?,?,?)",
134
+ (smi, float(pred), datetime.utcnow().isoformat()),
135
  )
136
  conn.commit()
137
 
138
+ st.download_button(
139
+ "Download CSV",
140
+ pd.DataFrame({"SMILES": vsmi, "Gap (eV)": [round(p, 4) for p in preds]})
141
+ .to_csv(index=False).encode(),
142
+ "homolumo_predictions.csv",
143
+ "text/csv",
144
+ )