"""Streamlit app that extracts and structures data from medical document images.

Text is pulled from an uploaded JPEG/PNG with both Tesseract OCR and the Gemini
API. Gemini is then prompted to classify the document and to return structured
JSON (patient info, symptoms, visits, vital signs, medications). The result is
shown as an EHR-style view and can be downloaded as JSON or as a PDF report.
"""

import base64
import io
import json
import logging
import os
import re
import time
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, Optional
from xml.sax.saxutils import escape as xml_escape

import pandas as pd
import pytesseract  # Tesseract OCR
import requests
import streamlit as st
from dotenv import load_dotenv  # For .env file
from PIL import Image
from reportlab.lib import colors
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import ParagraphStyle, getSampleStyleSheet
from reportlab.platypus import Paragraph, SimpleDocTemplate, Spacer, Table, TableStyle

# Load environment variables from a .env file (if present) before anything reads them.
load_dotenv()

# Path to the Tesseract executable. The default is the usual Linux location; set
# TESSERACT_CMD to override it (e.g. /opt/homebrew/bin/tesseract on Apple Silicon macOS).
pytesseract.pytesseract.tesseract_cmd = os.getenv("TESSERACT_CMD", r"/usr/bin/tesseract")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Leading list markers ("- ", "* ", "• ", "1. ", "2) ") that the model puts before list items.
_LIST_MARKER_RE = re.compile(r'^(?:[-*•]\s+|\d+[.)]\s+)')


def _display_value(value: Any) -> str:
    """Return *value* as display text, using 'N/A' when it is missing or empty."""
    if value is None or value == "":
        return "N/A"
    return str(value)


# Configuration and Constants
class Config:
    """Application settings; the Gemini endpoint and key can come from the environment."""

    GEMINI_URL = os.getenv(
        "GEMINI_URL",
        "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash-exp:generateContent",
    )
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # Load from .env
    MAX_RETRIES = 3
    TIMEOUT = 30  # seconds per HTTP request (requests' timeout unit)
    MAX_IMAGE_SIZE = (1600, 1600)
    ALLOWED_MIME_TYPES = ["image/jpeg", "image/png"]
    MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB
    # Rate limiting and transient server errors are retried; other HTTP errors fail fast.
    RETRYABLE_STATUS_CODES = frozenset({429, 500, 502, 503, 504})


# Custom Exceptions
class APIError(Exception):
    """Raised when the Gemini API cannot be reached or returns an unusable response."""


class ImageProcessingError(Exception):
    """Raised when an uploaded image cannot be preprocessed."""


# Initialize session state
def init_session_state():
    """Create the session-state keys this app reads, if they do not exist yet."""
    if 'processing_history' not in st.session_state:
        st.session_state.processing_history = []
    if 'current_document' not in st.session_state:
        st.session_state.current_document = None
    if 'pdf_history' not in st.session_state:
        st.session_state.pdf_history = []


# Page setup and styling
def setup_page():
    """Configure the Streamlit page. Must run before any other Streamlit command."""
    st.set_page_config(
        page_title="Medical Document Processor",
        page_icon="🏥",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # NOTE(review): the custom CSS block is empty here; restore the intended <style> or drop this call.
    st.markdown("""
 """, unsafe_allow_html=True)


class PDFGenerator:
    """Render structured EHR data as a PDF report with ReportLab."""

    @staticmethod
    def _table_style() -> TableStyle:
        """Return the bordered, padded style shared by every table in the report."""
        # TableStyle has no 'PADDING' shorthand; each side must be set explicitly.
        return TableStyle([
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('LEFTPADDING', (0, 0), (-1, -1), 6),
            ('RIGHTPADDING', (0, 0), (-1, -1), 6),
            ('TOPPADDING', (0, 0), (-1, -1), 6),
            ('BOTTOMPADDING', (0, 0), (-1, -1), 6),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
        ])

    @staticmethod
    def _cell(value: Any, style: ParagraphStyle) -> Paragraph:
        """Wrap *value* in a Paragraph so long text wraps inside its table column."""
        # Paragraph parses XML-like markup, so &, < and > in model output must be escaped.
        return Paragraph(xml_escape(_display_value(value)), style)

    @staticmethod
    def _record_rows(records: Any, keys: tuple, style: ParagraphStyle) -> list:
        """Build one table row per dict in *records*, reading *keys* in order.

        Non-list input and non-dict items are skipped; missing keys become 'N/A'.
        """
        if not isinstance(records, list):
            return []
        return [
            [PDFGenerator._cell(record.get(key), style) for key in keys]
            for record in records
            if isinstance(record, dict)
        ]

    @staticmethod
    def create_pdf(data: Optional[Dict[str, Any]]) -> bytes:
        """Build the PDF report for *data* and return it as bytes.

        Args:
            data: Structured output of ``DocumentProcessor.extract_structured_data``.
                ``None`` (nothing was extracted) yields a report of N/A values.

        Returns:
            The rendered PDF document.
        """
        data = data or {}
        buffer = io.BytesIO()
        doc = SimpleDocTemplate(buffer, pagesize=letter)
        styles = getSampleStyleSheet()
        body_style = styles['BodyText']
        elements = []

        # Title
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=24,
            spaceAfter=30
        )
        elements.append(Paragraph("Medical Document Report", title_style))
        elements.append(Spacer(1, 20))

        # Patient Information
        elements.append(Paragraph("Patient Information", styles['Heading2']))
        patient_info = data.get('patient_info')
        if not isinstance(patient_info, dict):
            patient_info = {}
        patient_data = [
            ["Name:", PDFGenerator._cell(patient_info.get('name'), body_style)],
            ["Age:", PDFGenerator._cell(patient_info.get('age'), body_style)],
            ["Gender:", PDFGenerator._cell(patient_info.get('gender'), body_style)],
        ]
        patient_table = Table(patient_data, colWidths=[100, 300])
        patient_table.setStyle(PDFGenerator._table_style())
        elements.append(patient_table)
        elements.append(Spacer(1, 20))

        # Symptoms -- Paragraph collapses newlines, so items are separated with <br/>.
        symptoms = data.get('symptoms')
        if isinstance(symptoms, str):
            symptoms = [symptoms]
        if symptoms:
            elements.append(Paragraph("Symptoms", styles['Heading2']))
            symptoms_text = "<br/>".join(
                f"- {xml_escape(str(symptom))}" for symptom in symptoms
            )
            elements.append(Paragraph(symptoms_text, body_style))
            elements.append(Spacer(1, 20))

        # Vital Signs
        vital_rows = PDFGenerator._record_rows(
            data.get('vital_signs'), ('type', 'value'), body_style
        )
        if vital_rows:
            elements.append(Paragraph("Vital Signs", styles['Heading2']))
            vital_signs_table = Table(
                [["Type", "Value"]] + vital_rows, colWidths=[200, 200], repeatRows=1
            )
            vital_signs_table.setStyle(PDFGenerator._table_style())
            elements.append(vital_signs_table)
            elements.append(Spacer(1, 20))

        # Medications
        med_rows = PDFGenerator._record_rows(
            data.get('medications'), ('name', 'dosage', 'instructions'), body_style
        )
        if med_rows:
            elements.append(Paragraph("Medications", styles['Heading2']))
            meds_table = Table(
                [["Name", "Dosage", "Instructions"]] + med_rows,
                colWidths=[150, 100, 250],
                repeatRows=1,
            )
            meds_table.setStyle(PDFGenerator._table_style())
            elements.append(meds_table)
            elements.append(Spacer(1, 20))

        doc.build(elements)
        return buffer.getvalue()


class ImageProcessor:
    """Validation and normalization of uploaded document images."""

    # Pillow reports multi-picture JPEGs (common from phone cameras) as "MPO".
    SUPPORTED_FORMATS = frozenset({'JPEG', 'PNG', 'MPO'})

    @staticmethod
    def validate_image(uploaded_file) -> tuple[bool, str]:
        """Check the upload's size and decoded format.

        Returns:
            ``(is_valid, message)``; *message* explains any rejection.
        """
        try:
            if uploaded_file.size > Config.MAX_FILE_SIZE:
                return False, f"File size exceeds {Config.MAX_FILE_SIZE // (1024*1024)}MB limit"
            image = Image.open(uploaded_file)
            if (image.format or '').upper() not in ImageProcessor.SUPPORTED_FORMATS:
                return False, "Unsupported image format. Please upload JPEG or PNG"
            return True, "Image validation successful"
        except Exception as e:
            logger.error("Image validation error: %s", e)
            return False, f"Image validation failed: {str(e)}"

    @staticmethod
    def preprocess_image(image: Image.Image) -> Image.Image:
        """Return an RGB copy of *image*, downscaled to fit ``Config.MAX_IMAGE_SIZE``.

        The input image is never modified.

        Raises:
            ImageProcessingError: if conversion or resizing fails.
        """
        try:
            max_width, max_height = Config.MAX_IMAGE_SIZE
            needs_resize = image.size[0] > max_width or image.size[1] > max_height
            if image.mode != 'RGB':
                image = image.convert('RGB')  # convert() returns a new image
            elif needs_resize:
                # thumbnail() resizes in place; copy so the caller's image stays intact.
                image = image.copy()
            if needs_resize:
                image.thumbnail(Config.MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
            return image
        except Exception as e:
            logger.error("Image preprocessing error: %s", e)
            raise ImageProcessingError(f"Failed to preprocess image: {str(e)}") from e


class DocumentProcessor:
    """Run OCR and Gemini prompts over a document image and structure the result."""

    def __init__(self):
        self.image_processor = ImageProcessor()

    def process_document(self, image: Image.Image) -> Dict[str, Any]:
        """OCR *image*, classify it, and extract structured data.

        Returns:
            A dict with ``document_type``, ``extracted_text`` and ``structured_data``.
            ``structured_data`` is ``None`` when neither OCR engine found any text.
        """
        try:
            processed_image = self.image_processor.preprocess_image(image)

            # Extract text using Tesseract OCR (best effort)
            tesseract_text = self.extract_text_with_tesseract(processed_image)

            # Extract text using Gemini API
            image_base64 = self.encode_image(processed_image)
            gemini_text = self.extract_text_with_gemini(image_base64)

            # Combine results from Tesseract and Gemini
            combined_text = self.combine_text_results(tesseract_text, gemini_text)

            results = {
                "document_type": self.classify_document(combined_text),
                "extracted_text": combined_text,
                "structured_data": None
            }

            # combined_text always contains section headers, so test the raw OCR outputs.
            if gemini_text.strip() or tesseract_text.strip():
                results["structured_data"] = self.extract_structured_data(combined_text)

            return results
        except Exception as e:
            logger.error("Document processing error: %s", e)
            raise

    @staticmethod
    def extract_text_with_tesseract(image: Image.Image) -> str:
        """Run Tesseract on *image*; return "" if Tesseract is missing or fails."""
        try:
            return pytesseract.image_to_string(image)
        except (pytesseract.TesseractNotFoundError, pytesseract.TesseractError) as e:
            logger.warning("Tesseract OCR unavailable, continuing with Gemini only: %s", e)
            return ""

    @staticmethod
    def encode_image(image: Image.Image) -> str:
        """Encode *image* as base64 JPEG (quality 95) for Gemini's inline_data."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG", quality=95)
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def extract_text_with_gemini(image_base64: str) -> str:
        """Ask Gemini to transcribe the visible text of the encoded image."""
        prompt = """
        Extract all visible text from this medical document. Include:
        - Headers and titles
        - Patient information
        - Medical data and values
        - Notes and annotations
        - Dates and timestamps
        Format the output in a clear, structured manner.
        """
        response = GeminiAPI.call_api(prompt, image_base64)
        return GeminiAPI.extract_text(response)

    @staticmethod
    def combine_text_results(tesseract_text: str, gemini_text: str) -> str:
        """Join both OCR outputs under labelled headers, Gemini's first."""
        # Combine results, prioritizing Gemini's output but adding Tesseract's output for completeness
        combined_text = f"Gemini Extracted Text:\n{gemini_text}\n\nTesseract Extracted Text:\n{tesseract_text}"
        return combined_text

    def classify_document(self, text: str) -> str:
        """Ask Gemini for the document category of *text* and return its answer."""
        prompt = f"""
        Analyze this medical document and classify it into one of the following categories:
        - Lab Report
        - Patient Chart
        - Prescription
        - Imaging Report
        - Medical Certificate
        - Other (specify)

        Provide only the category name.

        Document Text:
        {text}
        """
        response = GeminiAPI.call_api(prompt)
        return GeminiAPI.extract_text(response)

    def extract_structured_data(self, text: str) -> Dict[str, Any]:
        """Ask Gemini for structured EHR JSON describing *text*, then post-process it.

        Post-processing normalizes missing or mistyped sections, guesses a gender
        when a name is known but no gender is, standardizes medication names, and
        replaces the symptom list with a dedicated symptom-extraction pass.

        Raises:
            APIError: if a Gemini call fails.
            ValueError: if no JSON object can be parsed from the response.
        """
        prompt = f"""
        Analyze this medical text and return a valid JSON object with the following structure:
        {{
            "patient_info": {{
                "name": "string",
                "age": "string",
                "gender": "string"
            }},
            "symptoms": ["string"],
            "visits": [
                {{
                    "date": "string",
                    "reason": "string",
                    "notes": "string"
                }}
            ],
            "vital_signs": [
                {{
                    "type": "string",
                    "value": "string"
                }}
            ],
            "medications": [
                {{
                    "name": "string",
                    "dosage": "string",
                    "instructions": "string"
                }}
            ]
        }}

        Text to analyze:
        {text}
        """
        response = GeminiAPI.call_api(prompt)
        structured_data = self.parse_json_response(response)

        # The model may omit sections or return them with the wrong type; normalize
        # them so the code below, the EHR view and the PDF can rely on their shape.
        if not isinstance(structured_data.get('patient_info'), dict):
            structured_data['patient_info'] = {}
        for key in ('visits', 'vital_signs', 'medications'):
            if not isinstance(structured_data.get(key), list):
                structured_data[key] = []
        patient_info = structured_data['patient_info']

        # Predict gender if not mentioned -- only meaningful when a name is present.
        # NOTE(review): a name-based guess is not clinical data; consider labelling or dropping it.
        if not patient_info.get('gender') and patient_info.get('name'):
            patient_info['gender'] = self.predict_gender(str(patient_info['name']))

        # Correct medicine names
        structured_data['medications'] = [
            self.correct_medicine_name(med)
            for med in structured_data['medications']
            if isinstance(med, dict)
        ]

        # Improve symptoms extraction
        structured_data['symptoms'] = self.extract_symptoms(text)

        return structured_data

    @staticmethod
    def predict_gender(name: str) -> str:
        """Predict gender based on the patient's name.

        Returns 'Male' or 'Female' when the model answers with exactly one of
        those words, otherwise 'Unknown' (free-text replies are not stored).
        """
        prompt = f"""
        Based on the name '{name}', predict the gender.
        Return only 'Male' or 'Female'.
        """
        response = GeminiAPI.call_api(prompt)
        answer = GeminiAPI.extract_text(response).strip(" .'\"").lower()
        return {'male': 'Male', 'female': 'Female'}.get(answer, 'Unknown')

    @staticmethod
    def correct_medicine_name(medication: Dict[str, Any]) -> Dict[str, Any]:
        """Replace ``medication['name']`` in place with Gemini's standardized form.

        A medication without a name, or an empty model reply, leaves the dict unchanged.
        """
        name = medication.get('name')
        if not name:
            return medication
        prompt = f"""
        Correct the following medicine name to its standard form:
        {name}
        Return only the corrected name.
        """
        response = GeminiAPI.call_api(prompt)
        corrected = GeminiAPI.extract_text(response)
        if corrected:
            medication['name'] = corrected
        return medication

    @staticmethod
    def extract_symptoms(text: str) -> list[str]:
        """Ask Gemini for the symptoms in *text*; return one symptom per list item.

        Leading list markers are stripped because the UI and PDF add their own bullet.
        """
        prompt = f"""
        Extract all symptoms mentioned in the following medical text.
        Return only a list of symptoms:
        {text}
        """
        response = GeminiAPI.call_api(prompt)
        symptoms = []
        for line in GeminiAPI.extract_text(response).split("\n"):
            symptom = _LIST_MARKER_RE.sub('', line.strip()).strip()
            if symptom:
                symptoms.append(symptom)
        return symptoms

    @staticmethod
    def parse_json_response(response: Dict[str, Any]) -> Dict[str, Any]:
        """Return the first JSON object embedded in a Gemini text response.

        Raises:
            APIError: if the response has no text.
            ValueError: if no JSON object can be decoded.
        """
        try:
            response_text = GeminiAPI.extract_text(response)
            start = response_text.find('{')
            if start == -1:
                raise ValueError("No JSON object found in response")
            try:
                # raw_decode stops at the end of the first complete object, so trailing
                # prose or a closing ``` fence containing braces is ignored.
                parsed, _ = json.JSONDecoder().raw_decode(response_text, start)
            except json.JSONDecodeError:
                # Fall back to the outermost {...} span.
                parsed = json.loads(response_text[start:response_text.rindex('}') + 1])
            if not isinstance(parsed, dict):
                raise ValueError("No JSON object found in response")
            return parsed
        except Exception as e:
            logger.error("JSON parsing error: %s", e)
            raise


class EHRViewer:
    """Streamlit rendering of the structured EHR dictionary."""

    @staticmethod
    def display_ehr(data: Dict[str, Any]):
        """Render patient info, symptoms, vital signs and medications from *data*."""
        st.markdown("## 📊 Electronic Health Record")

        with st.container():
            st.markdown("### 👤 Patient Information")
            cols = st.columns(3)
            patient_info = data.get('patient_info')
            if not isinstance(patient_info, dict):
                patient_info = {}
            cols[0].metric("Name", _display_value(patient_info.get('name')))
            cols[1].metric("Age", _display_value(patient_info.get('age')))
            cols[2].metric("Gender", _display_value(patient_info.get('gender')))

            if data.get('symptoms'):
                st.markdown("### 🤒 Symptoms")
                symptoms_text = "\n".join([f"- {symptom}" for symptom in data['symptoms']])
                st.markdown(symptoms_text)

            if data.get('vital_signs'):
                st.markdown("### 🫀 Vital Signs")
                vital_signs_df = pd.DataFrame(data['vital_signs'])
                st.dataframe(vital_signs_df, use_container_width=True)

            if data.get('medications'):
                st.markdown("### 💊 Medications")
                med_df = pd.DataFrame(data['medications'])
                st.dataframe(med_df, use_container_width=True)


class GeminiAPI:
    """Minimal client for the Gemini ``generateContent`` REST endpoint."""

    @staticmethod
    def call_api(prompt: str, image_base64: Optional[str] = None) -> Dict[str, Any]:
        """POST *prompt* (plus an optional base64 JPEG) and return the decoded JSON.

        Connection errors, timeouts and HTTP statuses in
        ``Config.RETRYABLE_STATUS_CODES`` are retried with exponential backoff
        (1s, 2s, ...). Other HTTP errors fail immediately.

        Raises:
            APIError: if no API key is configured or every attempt fails.
        """
        if not Config.GEMINI_API_KEY:
            raise APIError("GEMINI_API_KEY is not set; add it to the environment or a .env file")

        headers = {
            "Content-Type": "application/json",
            # Sending the key as a header (not a ?key= query parameter) keeps it out of
            # the URL, which requests embeds in HTTPError messages that get logged/shown.
            "x-goog-api-key": Config.GEMINI_API_KEY,
        }
        parts = [{"text": prompt}]
        if image_base64:
            parts.append({
                "inline_data": {
                    "mime_type": "image/jpeg",
                    "data": image_base64
                }
            })
        payload = {"contents": [{"parts": parts}]}

        attempts = max(1, Config.MAX_RETRIES)
        last_error: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                response = requests.post(
                    Config.GEMINI_URL,
                    headers=headers,
                    json=payload,
                    timeout=Config.TIMEOUT
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.HTTPError as e:
                error_response = e.response
                status = error_response.status_code if error_response is not None else None
                if status not in Config.RETRYABLE_STATUS_CODES:
                    body = error_response.text[:500] if error_response is not None else ""
                    logger.error("Gemini API returned HTTP %s: %s", status, body)
                    raise APIError(f"API call failed: {str(e)}") from e
                last_error = e
            except requests.exceptions.RequestException as e:
                last_error = e
            logger.warning("Gemini API attempt %d/%d failed: %s", attempt + 1, attempts, last_error)
            if attempt < attempts - 1:
                time.sleep(2 ** attempt)

        logger.error("API call failed after %d attempts: %s", attempts, last_error)
        raise APIError(f"API call failed: {str(last_error)}") from last_error

    @staticmethod
    def extract_text(response: Dict[str, Any]) -> str:
        """Return the stripped text of the first candidate's first part.

        Raises:
            APIError: if the response carries no text (e.g. the prompt was blocked).
        """
        try:
            return response["candidates"][0]["content"]["parts"][0]["text"].strip()
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            feedback = response.get("promptFeedback") if isinstance(response, dict) else None
            raise APIError(f"Gemini response contained no text (prompt feedback: {feedback})") from e


def _render_sidebar():
    """List every PDF generated this session in the sidebar."""
    with st.sidebar:
        st.header("📋 Processing History")
        if not st.session_state.pdf_history:
            st.info("No documents processed yet")
            return
        for idx, pdf_record in enumerate(st.session_state.pdf_history):
            with st.expander(f"Document {idx + 1}: {pdf_record['timestamp']}"):
                st.download_button(
                    "📄 Download PDF",
                    pdf_record['data'],
                    file_name=pdf_record['filename'],
                    mime="application/pdf",
                    # The index keeps keys unique when two reports share a timestamp.
                    key=f"sidebar_{idx}_{pdf_record['timestamp']}"
                )


def _process_and_store(image: Image.Image, file_key: tuple) -> Dict[str, Any]:
    """Process *image*, build its PDF, record both in session state, and return the record."""
    processor = DocumentProcessor()
    results = processor.process_document(image)

    # Generate PDF
    pdf_bytes = PDFGenerator.create_pdf(results['structured_data'])
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    pdf_filename = f"medical_report_{timestamp}.pdf"

    document = {
        'timestamp': timestamp,
        'results': results,
        'file_key': file_key,
        'pdf_bytes': pdf_bytes,
        'pdf_filename': pdf_filename,
    }
    st.session_state.processing_history.append(document)
    st.session_state.pdf_history.append({
        'timestamp': timestamp,
        'filename': pdf_filename,
        'data': pdf_bytes
    })
    return document


def _render_pdf_history():
    """List every PDF generated this session with a download button."""
    st.markdown("### 📚 PDF History")
    if not st.session_state.pdf_history:
        st.info("No PDF history available")
        return
    for idx, pdf_record in enumerate(st.session_state.pdf_history):
        col1, col2 = st.columns([3, 1])
        with col1:
            st.write(f"Report from {pdf_record['timestamp']}")
        with col2:
            st.download_button(
                "📄 View PDF",
                pdf_record['data'],
                file_name=pdf_record['filename'],
                mime="application/pdf",
                key=f"history_{idx}_{pdf_record['timestamp']}"
            )


def _render_results(document: Dict[str, Any]):
    """Show the processing results, download buttons and PDF history for *document*."""
    results = document['results']
    timestamp = document['timestamp']

    st.success("Document processed successfully!")
    st.markdown(f"**Document Type:** {results['document_type']}")

    with st.expander("View Extracted Text"):
        st.text_area(
            "Raw Text",
            results['extracted_text'],
            height=200
        )

    # Display EHR View
    if results['structured_data']:
        EHRViewer.display_ehr(results['structured_data'])

    # Download options
    st.markdown("### 📥 Download Options")
    col1, col2 = st.columns(2)
    with col1:
        json_str = json.dumps(results['structured_data'], indent=2)
        st.download_button(
            "⬇️ Download JSON",
            json_str,
            file_name=f"medical_data_{timestamp}.json",
            mime="application/json"
        )
    with col2:
        st.download_button(
            "📄 Download PDF Report",
            document['pdf_bytes'],
            file_name=document['pdf_filename'],
            mime="application/pdf"
        )

    _render_pdf_history()


def _handle_upload(uploaded_file):
    """Validate and display the upload, process it on request, and show its results."""
    # Validate image
    is_valid, message = ImageProcessor.validate_image(uploaded_file)
    if not is_valid:
        st.error(message)
        return

    # Display image
    image = Image.open(uploaded_file)
    col1, col2 = st.columns([1, 2])
    with col1:
        st.image(image, caption="Uploaded Document", use_column_width=True)

    # Identifies this upload so results survive reruns (every download click reruns
    # the script) but are not shown for a different file.
    file_key = (uploaded_file.name, uploaded_file.size, getattr(uploaded_file, 'file_id', None))

    # Process document
    if st.button("🔍 Process Document"):
        with st.spinner("Processing document..."):
            st.session_state.current_document = _process_and_store(image, file_key)

    document = st.session_state.current_document
    if document and document.get('file_key') == file_key:
        with col2:
            _render_results(document)


def main():
    """Streamlit entry point."""
    # set_page_config must be the first Streamlit command of the run.
    setup_page()
    init_session_state()

    st.title("🏥 Advanced Medical Document Processor")
    st.markdown("Upload medical documents for automated processing and analysis.")

    # Main content
    uploaded_file = st.file_uploader(
        "Choose a medical document",
        type=['png', 'jpg', 'jpeg'],
        help="Upload a clear image of a medical document (max 5MB)"
    )

    if uploaded_file:
        try:
            _handle_upload(uploaded_file)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            logger.exception("Error in main processing loop")

    # Rendered last so a report generated during this run is already listed.
    _render_sidebar()


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        st.error("An unexpected error occurred. Please try again later.")
        logger.exception("Unhandled exception in main application")