# Provenance: Hugging Face Space (user "mgbam"), app.py, commit 6a39465, 4.75 kB.
# (The lines above were non-Python page chrome from the raw-file view and have
# been converted to this comment so the file parses.)
# app.py
import json
import os
from io import BytesIO

import openai
import requests  # used to fetch the generated image from the returned URL
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas
# ─── 1. Configuration & Secrets ─────────────────────────────────────────────
# Prefer Streamlit secrets, but fall back to the environment variable so the
# app also runs outside a deployed Streamlit instance (the original comment
# promised this fallback but never implemented it).
openai.api_key = st.secrets.get("OPENAI_API_KEY", os.getenv("OPENAI_API_KEY"))

st.set_page_config(
    page_title="MedSketch AI",  # was mojibake: "MedSketchβ€―AI"
    layout="wide",
    initial_sidebar_state="expanded",
)
# ─── 2. Sidebar: Settings & Metadata ────────────────────────────────────────
# Model/style controls feed the generation prompt; metadata fields are
# collected but not yet attached to outputs (TODO for a future revision).
st.sidebar.header("⚙️ Settings")  # was mojibake: "βš™οΈ Settings"

# NOTE: these option strings are compared verbatim at generation time
# (model_choice == "GPT-4o (API)") — do not reword them.
model_choice = st.sidebar.selectbox(
    "Model",
    ["GPT-4o (API)", "Stable Diffusion LoRA"],
    index=0,
)
style_preset = st.sidebar.radio(
    "Preset Style",
    ["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"],
)
# Stylization strength in [0.1, 1.0], default 0.7; embedded into the prompt.
strength = st.sidebar.slider("Stylization Strength", 0.1, 1.0, 0.7)

st.sidebar.markdown("---")
st.sidebar.header("📋 Metadata")  # was mojibake: "πŸ“‹ Metadata"
patient_id = st.sidebar.text_input("Patient / Case ID")
roi = st.sidebar.text_input("Region of Interest")
umls_code = st.sidebar.text_input("UMLS / SNOMED CT Code")
# ─── 3. Main: Prompt Input & Batch Generation ───────────────────────────────
st.title("🖼️ MedSketch AI – Advanced Clinical Diagram Generator")  # was mojibake

with st.expander("📝 Enter Prompts (one per line for batch)"):  # was mojibake: "πŸ“"
    raw = st.text_area(
        "Describe what you need:",
        placeholder=(
            # Curly quotes were mojibake ("β€œ…") in the original.
            "e.g. “Generate a labeled cross-section of the human heart with chamber names, valves, and flow arrows…”\n"
            "e.g. “Produce a stylized H&E stain of liver tissue highlighting portal triads…”"
        ),
        height=120,
    )

# One prompt per non-blank line, whitespace-trimmed.
prompts = [p.strip() for p in raw.splitlines() if p.strip()]
if st.button("🚀 Generate"):  # was mojibake: "πŸš€ Generate"
    if not prompts:
        st.error("Please enter at least one prompt.")
    else:
        # At most 3 columns. BUG FIX: the original indexed cols[i] directly,
        # which raises IndexError for more than 3 prompts; wrap with modulo
        # so extra prompts stack into the existing columns.
        cols = st.columns(min(3, len(prompts)))
        for i, prompt in enumerate(prompts):
            with st.spinner(f"Rendering image {i+1}/{len(prompts)}…"):
                if model_choice == "GPT-4o (API)":
                    # NOTE(review): openai.Image.create is the legacy (<1.0)
                    # SDK call, and "gpt-4o" is not an image-generation model —
                    # confirm against the installed openai version; current
                    # SDKs use client.images.generate(model="dall-e-3", ...).
                    resp = openai.Image.create(
                        model="gpt-4o",
                        prompt=f"[{style_preset} | strength={strength}] {prompt}",
                        size="1024x1024",
                    )
                    # Fetch the rendered PNG from the returned URL; bounded
                    # timeout so one stuck download can't hang the whole batch.
                    img_data = requests.get(
                        resp["data"][0]["url"], timeout=60
                    ).content
                else:
                    # Stub for the Stable Diffusion LoRA backend
                    # (generate_sd_image is expected to be defined elsewhere).
                    img_data = generate_sd_image(
                        prompt, style=style_preset, strength=strength
                    )
            img = Image.open(BytesIO(img_data))

            # Display + Download (wrap into the available columns — see above).
            with cols[i % len(cols)]:
                st.image(img, use_column_width=True, caption=prompt)
                buf = BytesIO()
                img.save(buf, format="PNG")
                st.download_button(
                    label="⬇️ Download PNG",
                    data=buf.getvalue(),
                    file_name=f"medsketch_{i+1}.png",
                    mime="image/png",
                )

                # ─── Annotation Canvas ───────────────────────────
                st.markdown("**✏️ Annotate:**")
                canvas_res = st_canvas(
                    fill_color="rgba(255, 0, 0, 0.3)",  # annotation color
                    stroke_width=2,
                    background_image=img,
                    update_streamlit=True,
                    height=512,
                    width=512,
                    drawing_mode="freedraw",
                    key=f"canvas_{i}",
                )
                # Persist drawn objects per prompt so the history section
                # below can list and export them.
                if canvas_res.json_data:
                    ann = canvas_res.json_data["objects"]
                    st.session_state.setdefault("annotations", {})[prompt] = ann
# ─── 4. History & Exports ───────────────────────────────────────────────────
# Show everything annotated this session and offer a single JSON export.
if "annotations" in st.session_state:
    st.markdown("---")
    st.subheader("📚 Session History & Annotations")  # was mojibake: "πŸ“š"
    for prm, objs in st.session_state["annotations"].items():
        st.markdown(f"**Prompt:** {prm}")
        st.json(objs)
    st.download_button(
        "⬇️ Export All Annotations (JSON)",
        data=json.dumps(st.session_state["annotations"], indent=2),
        file_name="medsketch_annotations.json",
        mime="application/json",
    )