File size: 4,746 Bytes
6a39465
 
21689c4
6a39465
21689c4
6a39465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d23f5f
6a39465
0d23f5f
6a39465
 
 
0d23f5f
6a39465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# app.py
import json
import os
import urllib.request
from io import BytesIO

import openai
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

# ─── 1. Configuration & Secrets ─────────────────────────────────────────────
# Page config is set before anything else touches Streamlit; older Streamlit
# releases require it to be the first Streamlit command on the page.
st.set_page_config(
    page_title="MedSketch AI",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Prefer the key from Streamlit secrets; fall back to the environment variable
# when secrets.toml is absent or does not define it.
# NOTE(review): the exception raised for a missing secrets file differs across
# Streamlit versions — confirm it is covered by this tuple for the pinned version.
try:
    openai.api_key = st.secrets["OPENAI_API_KEY"]
except (KeyError, FileNotFoundError):
    openai.api_key = os.getenv("OPENAI_API_KEY")

if not openai.api_key:
    st.warning("OPENAI_API_KEY is not configured; GPT-4o generation will fail.")

# ─── 2. Sidebar: Settings & Metadata ────────────────────────────────────────
st.sidebar.header("⚙️ Settings")
# The option labels are compared verbatim when choosing a backend during
# generation, so renaming one here requires updating that comparison too.
model_choice = st.sidebar.selectbox(
    "Model",
    ["GPT-4o (API)", "Stable Diffusion LoRA"],
    index=0
)
style_preset = st.sidebar.radio(
    "Preset Style",
    ["Anatomical Diagram", "H&E Histology", "IHC Pathology", "Custom"]
)
# Slider range 0.1–1.0, default 0.7.
strength = st.sidebar.slider("Stylization Strength", 0.1, 1.0, 0.7)

st.sidebar.markdown("---")
st.sidebar.header("📋 Metadata")
# NOTE(review): these three values are collected but not read anywhere else in
# this file (not attached to images or to the annotation export) — confirm intent.
patient_id = st.sidebar.text_input("Patient / Case ID")
roi = st.sidebar.text_input("Region of Interest")
umls_code = st.sidebar.text_input("UMLS / SNOMED CT Code")

# ─── 3. Main: Prompt Input & Batch Generation ───────────────────────────────
st.title("🖼️ MedSketch AI – Advanced Clinical Diagram Generator")

with st.expander("📝 Enter Prompts (one per line for batch)"):
    raw = st.text_area(
        "Describe what you need:",
        placeholder=(
            "e.g. “Generate a labeled cross‑section of the human heart with chamber names, valves, and flow arrows…”\n"
            "e.g. “Produce a stylized H&E stain of liver tissue highlighting portal triads…”"
        ),
        height=120
    )
# Batch mode: every non-blank line is one prompt, with surrounding whitespace removed.
prompts = [p.strip() for p in raw.splitlines() if p.strip()]

# ─── Generation helper ──────────────────────────────────────────────────────
def _generate_image_bytes(prompt, model, style, strength):
    """Render one prompt with the selected backend and return the encoded image bytes.

    ``model`` is one of the sidebar "Model" labels. Raises NotImplementedError
    when "Stable Diffusion LoRA" is selected but no ``generate_sd_image``
    function exists in this script's namespace.
    """
    if model == "GPT-4o (API)":
        # NOTE(review): `openai.Image.create` is the pre-1.0 SDK interface, and the
        # images endpoint may not accept "gpt-4o" as a model name — confirm both
        # against the installed `openai` version.
        resp = openai.Image.create(
            model="gpt-4o",
            prompt=f"[{style} | strength={strength}] {prompt}",
            size="1024x1024",
        )
        # The response carries a URL rather than pixels, so download the image.
        # urllib is used because `requests` was called here but never imported.
        with urllib.request.urlopen(resp["data"][0]["url"], timeout=60) as fh:
            return fh.read()

    # Stable Diffusion LoRA is still a stub: look the backend up instead of
    # letting an undefined name raise NameError in the middle of a batch.
    sd_backend = globals().get("generate_sd_image")
    if sd_backend is None:
        raise NotImplementedError(
            "The Stable Diffusion LoRA backend (generate_sd_image) is not implemented yet."
        )
    return sd_backend(prompt, style=style, strength=strength)


# ─── Generate: render every prompt and keep the results in session state ────
# Streamlit reruns the whole script on every interaction, including each canvas
# stroke (update_streamlit=True). st.button is True only on the run right after
# the click, so results are stored in st.session_state and the gallery below is
# drawn on every run; otherwise the images and canvases vanish on the next rerun.
if st.button("🚀 Generate"):
    if not prompts:
        st.error("Please enter at least one prompt.")
    else:
        generated = []
        for i, prompt in enumerate(prompts):
            with st.spinner(f"Rendering image {i+1}/{len(prompts)}…"):
                try:
                    img_data = _generate_image_bytes(
                        prompt, model_choice, style_preset, strength
                    )
                    img = Image.open(BytesIO(img_data))
                    img.load()  # decode now so a corrupt payload fails inside the try
                except Exception as exc:  # UI boundary: report it and continue the batch
                    st.error(f"Image {i+1} failed: {exc}")
                    continue
            png_buf = BytesIO()
            img.save(png_buf, format="PNG")
            generated.append(
                {"index": i, "prompt": prompt, "image": img, "png": png_buf.getvalue()}
            )
        st.session_state["generated"] = generated
        # A fresh batch id gives the widgets new keys, so strokes drawn on the
        # previous batch's images are not carried over onto the new ones.
        st.session_state["batch_id"] = st.session_state.get("batch_id", 0) + 1

# ─── Gallery: display, download and annotate (runs on every rerun) ──────────
gallery = st.session_state.get("generated", [])
if gallery:
    batch_id = st.session_state.get("batch_id", 0)
    cols = st.columns(min(3, len(gallery)))
    for slot, item in enumerate(gallery):
        n = item["index"] + 1  # 1-based position of the prompt in its batch
        # Wrap around the (at most 3) columns instead of indexing past the end.
        with cols[slot % len(cols)]:
            st.image(item["image"], use_column_width=True, caption=item["prompt"])
            st.download_button(
                label="⬇️ Download PNG",
                data=item["png"],
                file_name=f"medsketch_{n}.png",
                mime="image/png",
                key=f"download_{batch_id}_{n}",
            )

            # ─── Annotation Canvas ───────────────────────────
            st.markdown("**✏️ Annotate:**")
            canvas_res = st_canvas(
                fill_color="rgba(255, 0, 0, 0.3)",  # annotation color
                stroke_width=2,
                background_image=item["image"],
                update_streamlit=True,
                height=512,
                width=512,
                drawing_mode="freedraw",
                key=f"canvas_{batch_id}_{item['index']}",
            )
            # json_data is empty until the canvas reports its state; then store
            # its drawn objects under the prompt text for the history/export view.
            if canvas_res.json_data:
                ann = canvas_res.json_data.get("objects", [])
                st.session_state.setdefault("annotations", {})[item["prompt"]] = ann

# ─── 4. History & Exports ───────────────────────────────────────────────────
# The "annotations" dict maps prompt text -> list of canvas objects. It is
# created lazily the first time a canvas reports data.
if "annotations" in st.session_state:
    st.markdown("---")
    st.subheader("📚 Session History & Annotations")
    for prm, objs in st.session_state["annotations"].items():
        st.markdown(f"**Prompt:** {prm}")
        st.json(objs)
    # Export everything collected this session as one pretty-printed JSON file.
    st.download_button(
        "⬇️ Export All Annotations (JSON)",
        data=json.dumps(st.session_state["annotations"], indent=2),
        file_name="medsketch_annotations.json",
        mime="application/json"
    )