# witspathology / app.py
# Last change by IAMTFRMZA: "Update app.py" (commit 9836b6d, verified) - 13 kB.
import hmac
import os
import re
import time
from io import BytesIO

import requests
import streamlit as st
from openai import OpenAI
from PIL import Image
# ------------------ Page Config ------------------
# Browser-tab title and layout. This is the first st.* call in the script;
# the sidebar starts collapsed.
st.set_page_config(
    page_title="AI Pathology Assistant",
    layout="wide",
    initial_sidebar_state="collapsed",
)
# ------------------ Authentication ------------------
# Accounts allowed to use the app (email -> password).
# NOTE(review): plaintext credentials are committed to source control; move them
# to st.secrets / environment variables and store salted hashes instead.
# NOTE(review): the four keys below are identical in this copy ("[email protected]",
# apparently redacted), so this dict currently holds a single account -- restore
# the real addresses.
VALID_USERS = {
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
}


def _credentials_match(email, password):
    """Return True when *password* is the stored password for *email*.

    Whitespace around the email is ignored; the password is compared as typed.
    hmac.compare_digest is used so response time does not reveal how much of
    the password was correct.
    """
    expected = VALID_USERS.get(email.strip())
    if expected is None:
        return False
    # compare_digest only accepts ASCII-only str, so compare UTF-8 bytes instead.
    return hmac.compare_digest(expected.encode("utf-8"), password.encode("utf-8"))


def login():
    """Render the login form; on valid credentials unlock the session.

    Sets ``st.session_state.authenticated`` and reruns the script so the gate
    below lets the rest of the app render; otherwise shows an error.
    """
    st.title("πŸ” Login Required")
    email = st.text_input("Email")
    password = st.text_input("Password", type="password")
    if st.button("Login"):
        if _credentials_match(email, password):
            st.session_state.authenticated = True
            st.rerun()
        else:
            st.error("❌ Incorrect email or password.")


# Gate: nothing below this point runs until the session has logged in.
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False
if not st.session_state.authenticated:
    login()
    st.stop()
# ------------------ Load OpenAI ------------------
# The API key is read only from the process environment.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
api_key_missing = not OPENAI_API_KEY  # unset or empty string
if api_key_missing:
    st.error("❌ Missing OPENAI_API_KEY environment variable.")
    st.stop()
# Single client shared by both tabs.
client = OpenAI(api_key=OPENAI_API_KEY)

# ------------------ Assistant Setup ------------------
# OpenAI assistant that answers questions in the Chat Assistant tab.
ASSISTANT_ID: str = "asst_jXDSjCG8LI4HEaFEcjFVq8KB"
# ------------------ State Initialization ------------------
# Seed the chat tab's per-session state on the first run only; later reruns
# keep whatever is already stored. Each factory builds a fresh default.
_CHAT_STATE_FACTORIES = {
    "messages": list,                 # chat history as {"role", "content"} dicts
    "thread_id": lambda: None,        # OpenAI thread id, created on first question
    "image_urls": list,               # image links pulled from the latest reply
    "pending_prompt": lambda: None,   # prompt queued by sidebar / follow-up buttons
}
for _state_key, _make_default in _CHAT_STATE_FACTORIES.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# ------------------ UI Tabs ------------------
# Two top-level views: free-form chat (tab1) and histology image lookup (tab2).
_TAB_LABELS = ["πŸ’¬ Chat Assistant", "πŸ–ΌοΈ Visual Reference Search"]
tab1, tab2 = st.tabs(_TAB_LABELS)
# ------------------ Tab 1: Chat Assistant ------------------
# Run states in which the Assistants API is still working on its own; any other
# state (completed, failed, cancelled, expired, incomplete, requires_action)
# means polling can stop.
_CHAT_ACTIVE_RUN_STATES = ("queued", "in_progress", "cancelling")
_CHAT_RUN_TIMEOUT_S = 120  # stop waiting for a run after this many seconds
_COMMON_QUERY_PLACEHOLDER = "Select an action..."
_CHAT_IMAGE_URL_RE = re.compile(
    r'https://raw\.githubusercontent\.com/AndrewLORTech/witspathologai/main/[^\s\n"]+\.png'
)
# A follow-up item: a line starting with a bullet (-, *, the bullet glyphs) or "1." / "1)".
_FOLLOWUP_ITEM_RE = re.compile(r"^\s*(?:[-*•β€’]|\d+[.)])\s*(.+)$", re.MULTILINE)


def _queue_section_prompt():
    """on_change callback: queue a summary request for the entered section."""
    section = st.session_state.get("section_lookup", "")
    if section:
        st.session_state.pending_prompt = f"Summarize or list key points from section: {section}"


def _queue_common_query():
    """on_change callback: queue the chosen canned query, then reset the picker.

    Resetting to the placeholder lets the same query be chosen again and keeps
    the choice from being re-queued on every rerun.
    """
    choice = st.session_state.get("common_query", _COMMON_QUERY_PLACEHOLDER)
    if choice != _COMMON_QUERY_PLACEHOLDER:
        st.session_state.pending_prompt = choice
    st.session_state.common_query = _COMMON_QUERY_PLACEHOLDER


def _ask_chat_assistant(prompt):
    """Post *prompt* to this session's chat thread and wait for the answer.

    Creates the thread on first use (stored in ``st.session_state.thread_id``).
    Returns the reply text produced by this run, or None when the run does not
    complete (failed, expired, timed out, ...) or yields no text. OpenAI errors
    propagate to the caller.
    """
    if not st.session_state.thread_id:
        st.session_state.thread_id = client.beta.threads.create().id
    thread_id = st.session_state.thread_id
    client.beta.threads.messages.create(thread_id=thread_id, role="user", content=prompt)
    run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=ASSISTANT_ID)

    deadline = time.monotonic() + _CHAT_RUN_TIMEOUT_S
    while run.status in _CHAT_ACTIVE_RUN_STATES and time.monotonic() < deadline:
        time.sleep(1)
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)

    if run.status != "completed":
        if run.status in ("queued", "in_progress"):
            # Timed out: cancel so the thread is not left with an active run
            # (the API rejects new messages on such a thread).
            try:
                client.beta.threads.runs.cancel(thread_id=thread_id, run_id=run.id)
            except Exception:
                pass  # best effort; the caller reports the failure either way
        return None

    # messages.list is newest-first by default; take this run's latest reply and
    # keep only its text parts (content may also hold e.g. image parts).
    for message in client.beta.threads.messages.list(thread_id=thread_id).data:
        if message.role == "assistant" and message.run_id == run.id:
            text = "".join(
                part.text.value for part in message.content if part.type == "text"
            ).strip()
            return text or None
    return None


with tab1:
    with st.sidebar:
        st.header("πŸ§ͺ Pathology Tools")
        if st.button("🧹 Clear Chat"):
            st.session_state.messages = []
            st.session_state.thread_id = None
            st.session_state.image_urls = []
            st.session_state.pending_prompt = None
            st.rerun()
        show_image = st.toggle("πŸ“Έ Show Images", value=True)
        keyword = st.text_input("Keyword Search", placeholder="e.g. mitosis, carcinoma")
        if st.button("πŸ”Ž Search") and keyword:
            st.session_state.pending_prompt = f"Find clauses or references related to: {keyword}"
        # Callbacks queue a prompt only when the value changes; checking the
        # value on every rerun would resend it after each answer, endlessly.
        st.text_input(
            "Section Lookup",
            placeholder="e.g. Connective Tissue",
            key="section_lookup",
            on_change=_queue_section_prompt,
        )
        st.selectbox(
            "Common Pathology Queries",
            [
                _COMMON_QUERY_PLACEHOLDER,
                "List histological features of inflammation",
                "Summarize features of carcinoma",
                "List muscle types and features",
                "Extract diagnostic markers",
                "Summarize embryology stages",
            ],
            key="common_query",
            on_change=_queue_common_query,
        )

    chat_col, image_col = st.columns([2, 1])
    with chat_col:
        st.markdown("### πŸ’¬ Ask a Pathology-Specific Question")
        user_input = st.chat_input("Example: What are features of squamous cell carcinoma?")
        if user_input:
            st.session_state.messages.append({"role": "user", "content": user_input})
        elif st.session_state.pending_prompt:
            st.session_state.messages.append({"role": "user", "content": st.session_state.pending_prompt})
            st.session_state.pending_prompt = None

        # An unanswered user message at the end of the history means we owe a reply.
        if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
            reply = None
            try:
                with st.spinner("πŸ”¬ Analyzing..."):
                    reply = _ask_chat_assistant(st.session_state.messages[-1]["content"])
            except Exception as e:
                st.error(f"❌ Error: {e}")
            else:
                if reply is None:
                    st.error("❌ Assistant failed to complete.")

            if reply is None:
                # Drop the unanswered question; otherwise every later rerun
                # would post it to the thread again.
                st.session_state.messages.pop()
            else:
                st.session_state.messages.append({"role": "assistant", "content": reply})
                # Image links in order of appearance, without repeats.
                st.session_state.image_urls = list(dict.fromkeys(_CHAT_IMAGE_URL_RE.findall(reply)))
                # Outside the try so Streamlit's rerun signal is never caught above.
                st.rerun()

        # Display messages
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                # NOTE(review): unsafe_allow_html renders raw HTML from both user
                # input and model output; consider disabling it for user messages.
                st.markdown(msg["content"], unsafe_allow_html=True)

        # Follow-up questions: bullet items after the marker that end with "?".
        if st.session_state.messages and st.session_state.messages[-1]["role"] == "assistant":
            last = st.session_state.messages[-1]["content"]
            marker = "Some Possible Questions:"
            questions = []
            if marker in last:
                tail = last.split(marker, 1)[1]
                for item in _FOLLOWUP_ITEM_RE.findall(tail):
                    item = item.strip().strip("*").strip()  # drop markdown bold markers
                    if item.endswith("?"):
                        questions.append(item)
                questions = list(dict.fromkeys(questions))
            if questions:
                st.markdown("#### πŸ’‘ Follow-Up Suggestions")
                for i, q in enumerate(questions):
                    # Explicit key avoids duplicate-widget errors for repeated labels.
                    if st.button(f"πŸ“Œ {q}", key=f"followup_{i}"):
                        st.session_state.pending_prompt = q
                        st.rerun()
            else:
                st.markdown("#### πŸ’‘ No follow-up questions detected in the assistant's response.")

    with image_col:
        if show_image and st.session_state.image_urls:
            st.markdown("### πŸ–ΌοΈ Images")
            for url in st.session_state.image_urls:
                try:
                    response = requests.get(url, timeout=5)
                    response.raise_for_status()  # don't hand an HTML error page to PIL
                    img = Image.open(BytesIO(response.content))
                    st.image(img, caption=url.split("/")[-1], use_container_width=True)
                except Exception:
                    st.warning(f"⚠️ Failed to load image: {url}")
# ------------------ Tab 2: Visual Reference Search ------------------
# Dedicated assistant for image lookups. It has its own name so the chat tab's
# module-level ASSISTANT_ID is never overwritten.
VISUAL_ASSISTANT_ID = "asst_9v09zgizdcuuhNdcFQpRo9RO"
_VISUAL_ACTIVE_RUN_STATES = ("queued", "in_progress", "cancelling")
_VISUAL_RUN_TIMEOUT_S = 120  # stop waiting for a run after this many seconds


def _ask_visual_assistant(prompt):
    """Post *prompt* to the visual-search thread and wait for the answer.

    Creates ``st.session_state.image_thread_id`` on first use. Returns the
    reply text produced by this run, or None when the run does not complete
    or yields no text. OpenAI errors propagate to the caller.
    """
    if st.session_state.image_thread_id is None:
        st.session_state.image_thread_id = client.beta.threads.create().id
    thread_id = st.session_state.image_thread_id
    client.beta.threads.messages.create(thread_id=thread_id, role="user", content=prompt)
    run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=VISUAL_ASSISTANT_ID)

    deadline = time.monotonic() + _VISUAL_RUN_TIMEOUT_S
    while run.status in _VISUAL_ACTIVE_RUN_STATES and time.monotonic() < deadline:
        time.sleep(1)
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)

    if run.status != "completed":
        if run.status in ("queued", "in_progress"):
            # Timed out: cancel so the thread can accept the next query.
            try:
                client.beta.threads.runs.cancel(thread_id=thread_id, run_id=run.id)
            except Exception:
                pass  # best effort; the caller reports the failure either way
        return None

    # messages.list is newest-first by default, so this picks the reply to the
    # current query rather than the first reply ever posted in the thread.
    for message in client.beta.threads.messages.list(thread_id=thread_id).data:
        if message.role == "assistant" and message.run_id == run.id:
            text = "".join(part.text.value for part in message.content if part.type == "text")
            return text or None
    return None


def _extract_image_urls(text):
    """Return links labelled "Image URL:" in *text*, in order and without repeats.

    The link may follow the label on the same line or start a later line;
    markdown bold markers (``**``) are ignored.
    """
    urls = []
    expecting_url = False  # a label without an inline link was seen
    for line in text.splitlines():
        clean = line.strip().replace("**", "")
        if "Image URL:" in clean:
            after_label = clean.split("Image URL:")[1].strip()
            if after_label.startswith("http"):
                urls.append(after_label)
                expecting_url = False
            else:
                expecting_url = True
        elif expecting_url and clean.startswith("http"):
            urls.append(clean)
            expecting_url = False
    return list(dict.fromkeys(urls))


def _load_visual_image(url, timeout=10):
    """Download *url* and decode it with PIL; raises on HTTP or decode errors."""
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


with tab2:
    if "image_thread_id" not in st.session_state:
        st.session_state.image_thread_id = None
    if "image_response" not in st.session_state:
        st.session_state.image_response = None
    if "image_results" not in st.session_state:
        st.session_state.image_results = []
    if "image_lightbox" not in st.session_state:
        st.session_state.image_lightbox = None

    image_input = st.chat_input("Ask for histology visual references (e.g. ovary histology, mitosis)")
    if image_input:
        # A new query replaces whatever the previous one displayed.
        st.session_state.image_response = None
        st.session_state.image_results = []
        st.session_state.image_lightbox = None
        try:
            with st.spinner("πŸ”¬ Searching for histology references..."):
                response_text = _ask_visual_assistant(image_input)
        except Exception as e:
            st.error(f"❌ Visual Assistant Error: {e}")
        else:
            if response_text is None:
                st.error("❌ Visual assistant did not complete the search.")
            else:
                st.session_state.image_response = response_text
                image_urls = _extract_image_urls(response_text)
                st.session_state.image_results = [{"image": url} for url in image_urls]
                if image_urls:
                    st.session_state.image_lightbox = image_urls[0]

    text_col, image_col = st.columns([2, 1])
    with text_col:
        if st.session_state.image_response:
            st.markdown("### 🧠 Assistant Response")
            st.markdown(st.session_state.image_response, unsafe_allow_html=True)

    with image_col:
        if st.session_state.image_results:
            st.markdown("### πŸ–ΌοΈ Image Preview(s)")
            for i, item in enumerate(st.session_state.image_results):
                image_url = item.get("image")
                if not image_url:
                    continue
                # Percent-encode spaces etc. that the model may leave in links.
                encoded_url = requests.utils.requote_uri(image_url)
                try:
                    img = _load_visual_image(encoded_url)
                except Exception as e:
                    st.warning(f"⚠️ Could not load: {image_url}")
                    st.error(f"{e}")
                    continue
                st.image(img, caption=encoded_url.split("/")[-1], use_container_width=True)
                if st.button("πŸ” View Full Image", key=f"full_img_{i}"):
                    st.session_state.image_lightbox = encoded_url
        else:
            st.info("ℹ️ No image references found yet.")

        if st.session_state.image_lightbox:
            st.markdown("### πŸ”¬ Full Image View")
            # The first URL is stored raw, so encode it the same way as previews.
            img_url = requests.utils.requote_uri(st.session_state.image_lightbox)
            try:
                full_img = _load_visual_image(img_url)
            except Exception as e:
                st.warning("⚠️ Could not load full image.")
                st.error(str(e))
            else:
                st.image(full_img, caption=img_url.split("/")[-1], use_container_width=True)
            if st.button("❌ Close Preview"):
                st.session_state.image_lightbox = None
                st.rerun()