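# NOTE: the lines "Spaces:", "Sleeping", "Sleeping" were hosting-page status
# text captured together with this source; they are not part of the program.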
import base64
import io
import json
import logging
import os
import re
import time
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, Optional
from xml.sax.saxutils import escape

import pandas as pd
import pytesseract  # Tesseract OCR
import requests
import streamlit as st
from dotenv import load_dotenv  # For .env file
from PIL import Image
from reportlab.lib import colors
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle
# Populate os.environ from a .env file (if one is found) before any
# configuration below reads environment variables such as GEMINI_API_KEY.
load_dotenv()

# Process-wide logging: INFO and above, each record stamped with time,
# logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module logger shared by every class and function in this file.
logger = logging.getLogger(__name__)
# Configuration and Constants
class Config:
    """Static application settings read by the API client and the image checks."""

    # --- Upload / image limits ---
    MAX_FILE_SIZE = 5 * 1024 ** 2  # 5MB upload cap, in bytes
    ALLOWED_MIME_TYPES = ["image/jpeg", "image/png"]
    MAX_IMAGE_SIZE = (1600, 1600)  # bounding box passed to Image.thumbnail

    # --- Gemini REST endpoint ---
    GEMINI_URL = (
        "https://generativelanguage.googleapis.com/v1beta/models/"
        "gemini-2.0-flash-exp:generateContent"
    )
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # None when the variable is unset
    TIMEOUT = 30  # per-request timeout handed to requests.post, in seconds
    MAX_RETRIES = 3  # attempts per API call
# Custom Exceptions
class APIError(Exception):
    """Raised by GeminiAPI.call_api when a request to the Gemini endpoint fails."""


class ImageProcessingError(Exception):
    """Raised when an image cannot be converted/resized for processing."""
# Initialize session state
def init_session_state():
    """Seed Streamlit session state with the keys the app relies on.

    Only missing keys are created, so values already stored by earlier
    reruns of the script are preserved.
    """
    # Factories (not shared literals) so each missing list starts out fresh.
    session_defaults = (
        ('processing_history', list),
        ('current_document', lambda: None),
        ('pdf_history', list),
    )
    for state_key, make_default in session_defaults:
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()
# Page setup and styling
def setup_page():
    """Configure the Streamlit page and inject the app's custom CSS."""
    page_settings = {
        "page_title": "Medical Document Processor",
        "page_icon": "🏥",
        "layout": "wide",
        "initial_sidebar_state": "expanded",
    }
    st.set_page_config(**page_settings)

    # Card, header, button, history-item and metric styling. It is injected as
    # raw HTML, which is why unsafe_allow_html is required.
    custom_css = """
<style>
.main {padding: 2rem; max-width: 1200px; margin: 0 auto;}
.stCard {
background-color: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
margin: 1rem 0;
}
.header-container {
background-color: #f8f9fa;
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
}
.stButton>button {
background-color: #007bff;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 5px;
transition: all 0.3s ease;
}
.stButton>button:hover {
background-color: #0056b3;
transform: translateY(-2px);
}
.element-container {opacity: 1 !important;}
.pdf-history-item {
background-color: #f8f9fa;
padding: 1rem;
border-radius: 8px;
margin: 0.5rem 0;
border: 1px solid #dee2e6;
}
.metric-card {
background-color: #f8f9fa;
padding: 1rem;
border-radius: 8px;
border: 1px solid #dee2e6;
margin: 0.5rem 0;
}
</style>
"""
    st.markdown(custom_css, unsafe_allow_html=True)
class PDFGenerator:
    """Render a structured medical record as a simple ReportLab PDF report."""

    # Shared look for every table in the report: 1pt black grid, 6pt padding.
    _TABLE_STYLE = [
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('PADDING', (0, 0), (-1, -1), 6),
    ]

    @staticmethod
    def _markup_safe(value: Any) -> str:
        """Return *value* as text that is safe inside ReportLab Paragraph markup.

        ``None`` and ``""`` become ``'N/A'``. Paragraph parses an XML-like
        mini-language, so ``&``, ``<`` and ``>`` coming from OCR/LLM text must
        be escaped or building the paragraph fails.
        """
        if value is None or value == "":
            return "N/A"
        return escape(str(value))

    @staticmethod
    def _grid_table(rows, col_widths, cell_style) -> Table:
        """Build a gridded table whose cells are Paragraphs, so long text wraps."""
        cells = [
            [Paragraph(PDFGenerator._markup_safe(value), cell_style) for value in row]
            for row in rows
        ]
        table = Table(cells, colWidths=col_widths)
        table.setStyle(TableStyle(PDFGenerator._TABLE_STYLE))
        return table

    @staticmethod
    def create_pdf(data: Optional[Dict[str, Any]]) -> bytes:
        """Build the report and return the PDF file contents.

        Args:
            data: Structured record with optional keys ``patient_info`` (dict
                with name/age/gender), ``symptoms`` (list), ``vital_signs``
                (list of dicts with type/value) and ``medications`` (list of
                dicts with name/dosage/instructions). ``None`` is treated as an
                empty record; missing fields render as ``N/A``.

        Returns:
            The rendered PDF as bytes.
        """
        data = data or {}
        buffer = io.BytesIO()
        doc = SimpleDocTemplate(buffer, pagesize=letter)
        styles = getSampleStyleSheet()
        heading = styles['Heading2']
        body = styles['BodyText']
        elements = []

        # Title
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=24,
            spaceAfter=30,
        )
        elements.append(Paragraph("Medical Document Report", title_style))
        elements.append(Spacer(1, 20))

        # Patient Information (always rendered)
        patient_info = data.get('patient_info') or {}
        elements.append(Paragraph("Patient Information", heading))
        elements.append(PDFGenerator._grid_table(
            [
                ["Name:", patient_info.get('name')],
                ["Age:", patient_info.get('age')],
                ["Gender:", patient_info.get('gender')],
            ],
            [100, 300],
            body,
        ))
        elements.append(Spacer(1, 20))

        # Symptoms
        symptoms = data.get('symptoms')
        if symptoms:
            elements.append(Paragraph("Symptoms", heading))
            # Paragraph folds "\n" into ordinary whitespace; <br/> keeps one
            # symptom per line.
            bullets = "<br/>".join(
                f"- {PDFGenerator._markup_safe(symptom)}" for symptom in symptoms
            )
            elements.append(Paragraph(bullets, body))
            elements.append(Spacer(1, 20))

        # Vital Signs
        vital_signs = data.get('vital_signs')
        if vital_signs:
            elements.append(Paragraph("Vital Signs", heading))
            rows = [["Type", "Value"]] + [
                [vs.get('type'), vs.get('value')] for vs in vital_signs
            ]
            elements.append(PDFGenerator._grid_table(rows, [200, 200], body))
            elements.append(Spacer(1, 20))

        # Medications
        medications = data.get('medications')
        if medications:
            elements.append(Paragraph("Medications", heading))
            rows = [["Name", "Dosage", "Instructions"]] + [
                [med.get('name'), med.get('dosage'), med.get('instructions')]
                for med in medications
            ]
            elements.append(PDFGenerator._grid_table(rows, [150, 100, 250], body))
            elements.append(Spacer(1, 20))

        doc.build(elements)
        return buffer.getvalue()
class ImageProcessor:
    """Validation and normalisation of uploaded document images."""

    @staticmethod
    def validate_image(uploaded_file) -> tuple[bool, str]:
        """Check an upload against Config's size limit and allowed MIME types.

        Args:
            uploaded_file: file-like upload exposing ``.size`` (bytes) and
                readable by ``PIL.Image.open``.

        Returns:
            ``(ok, message)``. On failure ``ok`` is False and ``message`` is a
            user-facing reason. Unexpected errors are logged and reported as a
            failed validation rather than raised.
        """
        try:
            if uploaded_file.size > Config.MAX_FILE_SIZE:
                return False, f"File size exceeds {Config.MAX_FILE_SIZE // (1024*1024)}MB limit"
            image = Image.open(uploaded_file)
            detected_format = (image.format or "").upper()
            # Pillow reports multi-picture JPEGs (e.g. from some phone cameras)
            # as 'MPO'; they are still JPEG files, so validate them as JPEG.
            if detected_format == "MPO":
                detected_format = "JPEG"
            # Map PIL's format name to a MIME type so the allow-list lives in
            # one place (Config.ALLOWED_MIME_TYPES).
            if Image.MIME.get(detected_format) not in Config.ALLOWED_MIME_TYPES:
                return False, "Unsupported image format. Please upload JPEG or PNG"
            return True, "Image validation successful"
        except Exception as e:
            logger.error(f"Image validation error: {str(e)}")
            return False, f"Image validation failed: {str(e)}"

    @staticmethod
    def preprocess_image(image: Image.Image) -> Image.Image:
        """Return an RGB copy of *image* that fits inside Config.MAX_IMAGE_SIZE.

        The caller's image is never modified. Transparent regions are
        flattened onto white; a plain ``convert('RGB')`` would instead keep
        whatever colour lies under the alpha channel, which can hurt OCR.

        Raises:
            ImageProcessingError: if conversion or resizing fails.
        """
        try:
            has_alpha = image.mode in ('RGBA', 'LA') or (
                image.mode == 'P' and 'transparency' in image.info
            )
            if has_alpha:
                rgba = image.convert('RGBA')
                result = Image.new('RGB', rgba.size, (255, 255, 255))
                result.paste(rgba, mask=rgba.getchannel('A'))
            elif image.mode != 'RGB':
                result = image.convert('RGB')
            else:
                # thumbnail() works in place, so never resize the caller's object.
                result = image.copy()
            # thumbnail() keeps the aspect ratio and only ever shrinks.
            result.thumbnail(Config.MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
            return result
        except Exception as e:
            logger.error(f"Image preprocessing error: {str(e)}")
            raise ImageProcessingError(f"Failed to preprocess image: {str(e)}") from e
class DocumentProcessor:
    """Turn a document image into extracted text, a document type and structured data.

    Text comes from local Tesseract OCR (best effort) and from the Gemini API.
    Classification and structured extraction are done with Gemini prompts.
    """

    # Leading list marker on an LLM output line: "-", "*", "•", "1." or "1)"
    # followed by whitespace, so values such as "38.5 C fever" are left alone.
    _BULLET_PREFIX = re.compile(r'^(?:[-*•]|\d+[.)])\s+')

    def __init__(self):
        self.image_processor = ImageProcessor()

    @staticmethod
    def _response_text(response: Dict[str, Any]) -> str:
        """Return the stripped text of the first part of the first Gemini candidate."""
        return response["candidates"][0]["content"]["parts"][0]["text"].strip()

    def process_document(self, image: "Image.Image") -> Dict[str, Any]:
        """Run the whole pipeline on *image*.

        Returns:
            dict with ``document_type`` (str), ``extracted_text`` (str) and
            ``structured_data`` (dict, or None when no text was extracted).

        Raises:
            Any error from preprocessing or the Gemini calls; it is logged and
            re-raised unchanged.
        """
        try:
            processed_image = self.image_processor.preprocess_image(image)
            # Extract text using Tesseract OCR (optional) and the Gemini API
            tesseract_text = self._run_tesseract(processed_image)
            image_base64 = self.encode_image(processed_image)
            gemini_text = self.extract_text_with_gemini(image_base64)
            combined_text = self.combine_text_results(tesseract_text, gemini_text)
            results = {
                "document_type": self.classify_document(combined_text),
                "extracted_text": combined_text,
                "structured_data": None,
            }
            if results["extracted_text"]:
                results["structured_data"] = self.extract_structured_data(
                    results["extracted_text"]
                )
            return results
        except Exception as e:
            logger.error(f"Document processing error: {str(e)}")
            raise

    @staticmethod
    def _run_tesseract(image: "Image.Image") -> str:
        """Best-effort local OCR: return '' if Tesseract is unavailable or fails.

        Gemini's transcription is the primary text source, so a missing
        ``tesseract`` binary degrades the result instead of aborting it.
        """
        try:
            return pytesseract.image_to_string(image)
        except (pytesseract.TesseractError, OSError) as e:
            logger.warning("Tesseract OCR skipped: %s", e)
            return ""

    @staticmethod
    def encode_image(image: "Image.Image") -> str:
        """Encode *image* as a base64 JPEG (quality 95) string."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG", quality=95)
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def extract_text_with_gemini(image_base64: str) -> str:
        """Ask Gemini to transcribe the document image; return its text reply."""
        prompt = """
Extract all visible text from this medical document.
Include:
- Headers and titles
- Patient information
- Medical data and values
- Notes and annotations
- Dates and timestamps
Format the output in a clear, structured manner.
"""
        response = GeminiAPI.call_api(prompt, image_base64)
        return DocumentProcessor._response_text(response)

    @staticmethod
    def combine_text_results(tesseract_text: str, gemini_text: str) -> str:
        """Merge both transcriptions, Gemini's first.

        The Tesseract text is stripped (its output can carry trailing
        whitespace / form feeds) and its section is omitted when empty.
        """
        combined_text = f"Gemini Extracted Text:\n{gemini_text}"
        tesseract_text = (tesseract_text or "").strip()
        if tesseract_text:
            combined_text += f"\n\nTesseract Extracted Text:\n{tesseract_text}"
        return combined_text

    def classify_document(self, text: str) -> str:
        """Ask Gemini for the document category; return the reply text."""
        prompt = f"""
Analyze this medical document and classify it into one of the following categories:
- Lab Report
- Patient Chart
- Prescription
- Imaging Report
- Medical Certificate
- Other (specify)
Provide only the category name.
Document Text:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        return self._response_text(response)

    def extract_structured_data(self, text: str) -> Dict[str, Any]:
        """Extract patient info, symptoms, visits, vitals and medications as a dict.

        Post-processing: guarantees a ``patient_info`` dict, fills a missing
        gender from the patient name (only when a name exists), normalises
        medication names, and replaces ``symptoms`` with a dedicated extraction.

        Raises:
            ValueError / json.JSONDecodeError: if Gemini's reply holds no JSON object.
        """
        prompt = f"""
Analyze this medical text and return a valid JSON object with the following structure:
{{
"patient_info": {{
"name": "string",
"age": "string",
"gender": "string"
}},
"symptoms": ["string"],
"visits": [
{{
"date": "string",
"reason": "string",
"notes": "string"
}}
],
"vital_signs": [
{{
"type": "string",
"value": "string"
}}
],
"medications": [
{{
"name": "string",
"dosage": "string",
"instructions": "string"
}}
]
}}
Text to analyze:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        structured_data = self.parse_json_response(response)

        # The LLM may omit or null out patient_info; downstream code needs a dict.
        patient_info = structured_data.get('patient_info')
        if not isinstance(patient_info, dict):
            patient_info = {}
        structured_data['patient_info'] = patient_info

        # Predict gender if not mentioned (never from an empty name).
        # NOTE(review): this stores an LLM guess alongside extracted data in a
        # medical record; consider labelling it as inferred.
        name = patient_info.get('name')
        if not patient_info.get('gender') and name:
            patient_info['gender'] = self.predict_gender(name)

        # Correct medicine names
        structured_data['medications'] = [
            self.correct_medicine_name(med)
            for med in structured_data.get('medications') or []
        ]
        # Improve symptoms extraction
        structured_data['symptoms'] = self.extract_symptoms(text)
        return structured_data

    @staticmethod
    def predict_gender(name: str) -> str:
        """Ask Gemini to guess 'Male' or 'Female' from the patient's name."""
        prompt = f"""
Based on the name '{name}', predict the gender. Return only 'Male' or 'Female'.
"""
        response = GeminiAPI.call_api(prompt)
        return DocumentProcessor._response_text(response)

    @staticmethod
    def correct_medicine_name(medication: Dict[str, Any]) -> Dict[str, Any]:
        """Normalise ``medication['name']`` via Gemini, updating and returning the dict.

        Entries that are not dicts or have no name are returned unchanged.
        """
        if not isinstance(medication, dict) or not medication.get('name'):
            return medication
        prompt = f"""
Correct the following medicine name to its standard form:
{medication['name']}
Return only the corrected name.
"""
        response = GeminiAPI.call_api(prompt)
        medication['name'] = DocumentProcessor._response_text(response)
        return medication

    @staticmethod
    def extract_symptoms(text: str) -> list[str]:
        """Ask Gemini for the symptoms in *text*; return one cleaned entry per line."""
        prompt = f"""
Extract all symptoms mentioned in the following medical text. Return only a list of symptoms:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        return DocumentProcessor._parse_symptom_lines(DocumentProcessor._response_text(response))

    @staticmethod
    def _parse_symptom_lines(raw: str) -> list[str]:
        """Split an LLM list reply into items, dropping blanks and list markers.

        The EHR view and PDF prefix each symptom with "- " themselves, so a
        marker left in place would be rendered twice.
        """
        symptoms = []
        for line in raw.splitlines():
            cleaned = DocumentProcessor._BULLET_PREFIX.sub('', line.strip()).strip()
            if cleaned:
                symptoms.append(cleaned)
        return symptoms

    @staticmethod
    def parse_json_response(response: Dict[str, Any]) -> Dict[str, Any]:
        """Parse the outermost ``{...}`` span of Gemini's reply text as JSON.

        The regex spans from the first '{' to the last '}', which also skips
        surrounding prose or Markdown code fences.
        """
        try:
            response_text = DocumentProcessor._response_text(response)
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
            raise ValueError("No JSON object found in response")
        except Exception as e:
            logger.error(f"JSON parsing error: {str(e)}")
            raise
class EHRViewer:
    """Streamlit rendering of a structured record in an EHR-style layout."""

    @staticmethod
    def display_ehr(data: Dict[str, Any]):
        """Render the record: patient metrics, symptom bullets, and tables.

        Patient name/age/gender are shown as metrics (``N/A`` when missing);
        symptoms, vital signs and medications are shown only when present.
        ``None`` for *data* or for ``patient_info`` is treated as empty.
        """
        data = data or {}
        st.markdown("## 📊 Electronic Health Record")
        with st.container():
            st.markdown("### 👤 Patient Information")
            patient_info = data.get('patient_info') or {}
            fields = (("Name", 'name'), ("Age", 'age'), ("Gender", 'gender'))
            for col, (label, key) in zip(st.columns(3), fields):
                value = patient_info.get(key)
                # Stringify whatever the LLM produced so st.metric always gets a
                # supported type; missing values show as N/A.
                col.metric(label, 'N/A' if value in (None, "") else str(value))

        if data.get('symptoms'):
            st.markdown("### 🤒 Symptoms")
            symptoms_text = "\n".join(f"- {symptom}" for symptom in data['symptoms'])
            st.markdown(symptoms_text)

        if data.get('vital_signs'):
            st.markdown("### 🫀 Vital Signs")
            vital_signs_df = pd.DataFrame(data['vital_signs'])
            st.dataframe(vital_signs_df, use_container_width=True)

        if data.get('medications'):
            st.markdown("### 💊 Medications")
            med_df = pd.DataFrame(data['medications'])
            st.dataframe(med_df, use_container_width=True)
class GeminiAPI:
    """Minimal REST client for the Gemini ``generateContent`` endpoint."""

    # Statuses worth retrying: rate limiting and transient server-side errors.
    # Other 4xx responses (bad request, bad key) will not succeed on retry.
    _RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})

    @staticmethod
    def call_api(prompt: str, image_base64: Optional[str] = None) -> Dict[str, Any]:
        """POST *prompt* (plus an optional base64 JPEG) to Gemini; return the JSON reply.

        Connection errors, timeouts and retryable HTTP statuses are retried up
        to ``Config.MAX_RETRIES`` times with exponential backoff (1s, 2s, ...).

        Raises:
            APIError: if no API key is configured, on a non-retryable error,
                or once all attempts are exhausted.
        """
        if not Config.GEMINI_API_KEY:
            raise APIError("API call failed: GEMINI_API_KEY is not configured")

        # The key travels in a header rather than a ?key= query parameter:
        # requests includes the full URL in HTTPError messages, which are
        # logged and surfaced in the UI, and that would leak the key.
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": Config.GEMINI_API_KEY,
        }
        parts = [{"text": prompt}]
        if image_base64:
            parts.append({
                "inline_data": {
                    "mime_type": "image/jpeg",
                    "data": image_base64,
                }
            })
        payload = {"contents": [{"parts": parts}]}

        for attempt in range(Config.MAX_RETRIES):
            try:
                response = requests.post(
                    Config.GEMINI_URL,
                    headers=headers,
                    json=payload,
                    timeout=Config.TIMEOUT,
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                # No response (network error/timeout) counts as transient.
                status = getattr(e.response, "status_code", None)
                retryable = status is None or status in GeminiAPI._RETRYABLE_STATUS
                if not retryable or attempt == Config.MAX_RETRIES - 1:
                    logger.error(
                        "API call failed after %d attempt(s): %s", attempt + 1, e
                    )
                    raise APIError(f"API call failed: {str(e)}") from e
                time.sleep(2 ** attempt)
        # Only reachable when MAX_RETRIES < 1; never return None silently.
        raise APIError("API call failed: no attempts made (Config.MAX_RETRIES < 1)")
def _render_sidebar_history():
    """Sidebar list of every generated PDF report, each with a download button."""
    with st.sidebar:
        st.header("📋 Processing History")
        if not st.session_state.pdf_history:
            st.info("No documents processed yet")
            return
        for idx, pdf_record in enumerate(st.session_state.pdf_history):
            with st.expander(f"Document {idx + 1}: {pdf_record['timestamp']}"):
                st.download_button(
                    "📄 Download PDF",
                    pdf_record['data'],
                    file_name=pdf_record['filename'],
                    mime="application/pdf",
                    # Index-based key: timestamps have 1-second resolution and
                    # can collide, which Streamlit rejects as duplicate widgets.
                    key=f"sidebar_pdf_{idx}",
                )


def _render_pdf_history():
    """In-page list of generated PDF reports, each with a download button."""
    st.markdown("### 📚 PDF History")
    if not st.session_state.pdf_history:
        st.info("No PDF history available")
        return
    for idx, pdf_record in enumerate(st.session_state.pdf_history):
        label_col, button_col = st.columns([3, 1])
        with label_col:
            st.write(f"Report from {pdf_record['timestamp']}")
        with button_col:
            st.download_button(
                "📄 View PDF",
                pdf_record['data'],
                file_name=pdf_record['filename'],
                mime="application/pdf",
                key=f"history_pdf_{idx}",
            )


def main():
    """Streamlit entry point: upload a document image, process it, offer downloads.

    Each run stores the result in ``st.session_state.current_document`` and
    ``processing_history``. A generated PDF is also added to ``pdf_history``,
    which is listed in the sidebar and below the results.
    """
    # Keep st.set_page_config (inside setup_page) as the first Streamlit call
    # of the script run; session-state seeding comes after it.
    setup_page()
    init_session_state()

    st.title("🏥 Advanced Medical Document Processor")
    st.markdown("Upload medical documents for automated processing and analysis.")

    _render_sidebar_history()

    uploaded_file = st.file_uploader(
        "Choose a medical document",
        type=['png', 'jpg', 'jpeg'],
        help="Upload a clear image of a medical document (max 5MB)"
    )
    if not uploaded_file:
        return

    try:
        is_valid, message = ImageProcessor.validate_image(uploaded_file)
        if not is_valid:
            st.error(message)
            return

        image = Image.open(uploaded_file)
        image_col, results_col = st.columns([1, 2])
        with image_col:
            st.image(image, caption="Uploaded Document", use_column_width=True)

        if not st.button("🔍 Process Document"):
            return

        with st.spinner("Processing document..."):
            processor = DocumentProcessor()
            results = processor.process_document(image)
            structured = results['structured_data']
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            pdf_filename = f"medical_report_{timestamp}.pdf"
            # structured_data is None when no text was extracted; there is
            # nothing to put into a report in that case.
            pdf_bytes = None if structured is None else PDFGenerator.create_pdf(structured)

            st.session_state.current_document = {
                'timestamp': timestamp,
                'results': results,
            }
            st.session_state.processing_history.append(
                st.session_state.current_document
            )
            if pdf_bytes is not None:
                st.session_state.pdf_history.append({
                    'timestamp': timestamp,
                    'filename': pdf_filename,
                    'data': pdf_bytes,
                })

        with results_col:
            st.success("Document processed successfully!")
            st.markdown(f"**Document Type:** {results['document_type']}")
            with st.expander("View Extracted Text"):
                st.text_area("Raw Text", results['extracted_text'], height=200)

            if structured is None:
                st.warning("No structured data could be extracted from this document.")
            else:
                EHRViewer.display_ehr(structured)

                st.markdown("### 📥 Download Options")
                json_col, pdf_col = st.columns(2)
                with json_col:
                    st.download_button(
                        "⬇️ Download JSON",
                        json.dumps(structured, indent=2),
                        file_name=f"medical_data_{timestamp}.json",
                        mime="application/json"
                    )
                with pdf_col:
                    st.download_button(
                        "📄 Download PDF Report",
                        pdf_bytes,
                        file_name=pdf_filename,
                        mime="application/pdf"
                    )

            _render_pdf_history()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        logger.exception("Error in main processing loop")
if __name__ == "__main__":
    # Last-resort guard: users see a generic message instead of a traceback,
    # while the full stack trace still goes to the log.
    try:
        main()
    except Exception:
        st.error("An unexpected error occurred. Please try again later.")
        logger.exception("Unhandled exception in main application")