# MedDocDigitizer / app.py
# Author: rajsecrets0
# Last commit: 5bdbc3b "Update app.py" (verified)
import base64
import io
import json
import logging
import os
import re
import time
from datetime import datetime
from io import BytesIO
from typing import Dict, Any, Optional
from xml.sax.saxutils import escape

import fitz # PyMuPDF for PDF processing
import pandas as pd
import requests
import streamlit as st
from dotenv import load_dotenv
from PIL import Image, ImageOps
from reportlab.lib import colors
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
# Read variables such as GEMINI_API_KEY from a .env file so os.getenv can see them
load_dotenv()

# Application-wide logging: timestamp, logger name, level and message on each line
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Configuration and Constants
class Config:
    """Application-wide settings; secrets and the model name come from the environment / .env."""

    # Overridable without a code change (experimental model names get retired);
    # an unset or empty variable falls back to the original default.
    GEMINI_MODEL = os.getenv("GEMINI_MODEL") or "gemini-2.0-flash-exp"
    GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent"
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # Load from .env
    MAX_RETRIES = 3  # attempts per API call
    TIMEOUT = 30  # seconds, per HTTP request
    MAX_IMAGE_SIZE = (1600, 1600)  # images larger than this are shrunk before upload
    ALLOWED_MIME_TYPES = ["image/jpeg", "image/png", "application/pdf"]
    MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB
# Custom Exceptions
class APIError(Exception):
    """Raised when a call to the Gemini API fails."""


class ImageProcessingError(Exception):
    """Raised when an uploaded image cannot be pre-processed."""


class PDFProcessingError(Exception):
    """Raised when text cannot be extracted from an uploaded PDF."""
# Initialize session state
def init_session_state():
    """Create the session-state keys the app reads, leaving existing values untouched."""
    # (key, factory) pairs; factories give every key its own fresh default object.
    defaults = (
        ("processing_history", list),
        ("current_document", lambda: None),
        ("pdf_history", list),
    )
    for key, make_default in defaults:
        if key not in st.session_state:
            st.session_state[key] = make_default()
# Page setup and styling
def setup_page():
    """Configure the Streamlit page and inject the app's custom CSS."""
    page_settings = dict(
        page_title="Medical Document Processor",
        page_icon="🏥",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.set_page_config(**page_settings)

    page_css = """
<style>
.main {padding: 2rem; max-width: 1200px; margin: 0 auto;}
.stCard {
background-color: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
margin: 1rem 0;
}
.header-container {
background-color: #f8f9fa;
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
}
.stButton>button {
background-color: #007bff;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 5px;
transition: all 0.3s ease;
}
.stButton>button:hover {
background-color: #0056b3;
transform: translateY(-2px);
}
.element-container {opacity: 1 !important;}
.pdf-history-item {
background-color: #f8f9fa;
padding: 1rem;
border-radius: 8px;
margin: 0.5rem 0;
border: 1px solid #dee2e6;
}
.metric-card {
background-color: #f8f9fa;
padding: 1rem;
border-radius: 8px;
border: 1px solid #dee2e6;
margin: 0.5rem 0;
}
</style>
"""
    # The CSS is raw HTML, so Streamlit must be told not to escape it.
    st.markdown(page_css, unsafe_allow_html=True)
class PDFGenerator:
    """Render the structured medical data produced by DocumentProcessor as a PDF report."""

    @staticmethod
    def _text(value: Any) -> str:
        """Return *value* as text, using 'N/A' for missing (None) or empty values."""
        return 'N/A' if value is None or value == "" else str(value)

    @staticmethod
    def _grid_table(rows, col_widths) -> Table:
        """Build a table with the black grid and 6pt cell padding shared by every section."""
        table = Table(rows, colWidths=col_widths)
        table.setStyle(TableStyle([
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('LEFTPADDING', (0, 0), (-1, -1), 6),
            ('RIGHTPADDING', (0, 0), (-1, -1), 6),
            ('TOPPADDING', (0, 0), (-1, -1), 6),
            ('BOTTOMPADDING', (0, 0), (-1, -1), 6),
        ]))
        return table

    @staticmethod
    def create_pdf(data: Dict[str, Any]) -> bytes:
        """Build a report (patient info, symptoms, vital signs, medications) and return the PDF bytes.

        Args:
            data: Structured data with optional ``patient_info`` (dict),
                ``symptoms`` (list of str), ``vital_signs`` and ``medications``
                (lists of dicts). Missing/null sections are tolerated; ``None``
                yields a report with only an all-'N/A' patient section.

        Returns:
            The rendered PDF document as bytes.
        """
        data = data or {}
        buffer = io.BytesIO()
        doc = SimpleDocTemplate(buffer, pagesize=letter)
        styles = getSampleStyleSheet()
        body_style = styles['BodyText']

        def cell(value: Any) -> Paragraph:
            # Paragraph cells wrap long text within the column, and escape() keeps
            # '<' / '&' in model output from being parsed as Paragraph markup.
            return Paragraph(escape(PDFGenerator._text(value)), body_style)

        def records(key: str) -> list:
            # Only dict entries can become table rows; anything else is skipped.
            items = data.get(key)
            if not isinstance(items, list):
                return []
            return [item for item in items if isinstance(item, dict)]

        elements = []

        # Title
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=24,
            spaceAfter=30
        )
        elements.append(Paragraph("Medical Document Report", title_style))
        elements.append(Spacer(1, 20))

        # Patient Information (``or``/isinstance also covers an explicit null)
        elements.append(Paragraph("Patient Information", styles['Heading2']))
        patient_info = data.get('patient_info')
        if not isinstance(patient_info, dict):
            patient_info = {}
        patient_rows = [
            ["Name:", cell(patient_info.get('name'))],
            ["Age:", cell(patient_info.get('age'))],
            ["Gender:", cell(patient_info.get('gender'))],
        ]
        elements.append(PDFGenerator._grid_table(patient_rows, [100, 300]))
        elements.append(Spacer(1, 20))

        # Symptoms: joined with <br/> because Paragraph treats "\n" as plain whitespace.
        symptoms = data.get('symptoms')
        if isinstance(symptoms, str):
            symptoms = [symptoms]
        elif not isinstance(symptoms, list):
            symptoms = []
        if symptoms:
            elements.append(Paragraph("Symptoms", styles['Heading2']))
            symptoms_markup = "<br/>".join(f"- {escape(str(symptom))}" for symptom in symptoms)
            elements.append(Paragraph(symptoms_markup, body_style))
            elements.append(Spacer(1, 20))

        # Vital Signs
        vital_signs = records('vital_signs')
        if vital_signs:
            elements.append(Paragraph("Vital Signs", styles['Heading2']))
            vital_rows = [["Type", "Value"]] + [
                [cell(vs.get('type')), cell(vs.get('value'))] for vs in vital_signs
            ]
            elements.append(PDFGenerator._grid_table(vital_rows, [200, 200]))
            elements.append(Spacer(1, 20))

        # Medications: widths sum to 468pt, the frame width of a letter page with
        # SimpleDocTemplate's default 1-inch margins (the old 500pt overflowed it).
        medications = records('medications')
        if medications:
            elements.append(Paragraph("Medications", styles['Heading2']))
            med_rows = [["Name", "Dosage", "Instructions"]] + [
                [cell(med.get('name')), cell(med.get('dosage')), cell(med.get('instructions'))]
                for med in medications
            ]
            elements.append(PDFGenerator._grid_table(med_rows, [130, 100, 238]))
            elements.append(Spacer(1, 20))

        doc.build(elements)
        return buffer.getvalue()
class ImageProcessor:
    """Validation of uploads and pre-processing of images before they are sent for OCR."""

    @staticmethod
    def validate_file(uploaded_file) -> tuple[bool, str]:
        """Check an upload against the size and MIME-type limits in ``Config``.

        Args:
            uploaded_file: Object exposing ``size`` (bytes) and ``type`` (MIME string).

        Returns:
            ``(True, message)`` when acceptable, otherwise ``(False, reason)``.
            Unexpected errors are logged and reported as a failed validation
            rather than raised.
        """
        try:
            if uploaded_file.size > Config.MAX_FILE_SIZE:
                return False, f"File size exceeds {Config.MAX_FILE_SIZE // (1024*1024)}MB limit"
            if uploaded_file.type not in Config.ALLOWED_MIME_TYPES:
                return False, "Unsupported file type. Please upload JPEG, PNG, or PDF."
            return True, "File validation successful"
        except Exception as e:
            logger.error(f"File validation error: {str(e)}")
            return False, f"File validation failed: {str(e)}"

    @staticmethod
    def preprocess_image(image: Image.Image) -> Image.Image:
        """Return an upright RGB copy of *image* no larger than ``Config.MAX_IMAGE_SIZE``.

        Raises:
            ImageProcessingError: if any step fails.
        """
        try:
            # Phone photos often store rotation in EXIF metadata instead of the
            # pixels; apply it so the model sees the page upright.
            image = ImageOps.exif_transpose(image)
            has_alpha = image.mode in ("RGBA", "LA", "PA") or (
                image.mode == "P" and "transparency" in image.info
            )
            if has_alpha:
                # Flatten onto white: a plain convert('RGB') drops the alpha channel
                # and can expose a black background behind dark document text.
                rgba = image.convert("RGBA")
                flattened = Image.new("RGB", rgba.size, (255, 255, 255))
                flattened.paste(rgba, mask=rgba.getchannel("A"))
                image = flattened
            elif image.mode != 'RGB':
                image = image.convert('RGB')
            if image.size[0] > Config.MAX_IMAGE_SIZE[0] or image.size[1] > Config.MAX_IMAGE_SIZE[1]:
                # thumbnail() shrinks in place and preserves the aspect ratio.
                image.thumbnail(Config.MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
            return image
        except Exception as e:
            logger.error(f"Image preprocessing error: {str(e)}")
            raise ImageProcessingError(f"Failed to preprocess image: {str(e)}") from e
class DocumentProcessor:
    """Turn an uploaded image or PDF into raw text, a document type and structured data.

    Every language-model step goes through ``GeminiAPI.call_api``.
    """

    # Strips a leading list marker such as "- ", "* ", "• ", "1. " or "2) " from a line.
    # Whitespace is required after the marker so values like "1.5 cm" are kept intact.
    _LIST_MARKER = re.compile(r"^\s*(?:[-*•]|\d+[.)])\s+")

    def __init__(self):
        self.image_processor = ImageProcessor()

    def process_document(self, uploaded_file) -> Dict[str, Any]:
        """Extract, classify and structure the contents of *uploaded_file*.

        Args:
            uploaded_file: File-like object with a ``type`` MIME attribute
                (e.g. a Streamlit upload).

        Returns:
            Dict with ``document_type`` (str; "Unknown" when no text was found),
            ``extracted_text`` (str) and ``structured_data`` (dict, or ``None``
            when no text was found).

        Raises:
            ValueError: for MIME types other than ``image/*`` and ``application/pdf``.
            ImageProcessingError, PDFProcessingError, APIError: from the processing steps.
        """
        try:
            if uploaded_file.type.startswith("image/"):
                # Process image
                image = Image.open(uploaded_file)
                processed_image = self.image_processor.preprocess_image(image)
                image_base64 = self.encode_image(processed_image)
                extracted_text = self.extract_text(image_base64)
            elif uploaded_file.type == "application/pdf":
                # Process PDF
                extracted_text = self.extract_text_from_pdf(uploaded_file)
            else:
                raise ValueError("Unsupported file type.")

            # Without text (e.g. a PDF with no text layer) there is nothing to
            # classify or structure, so skip those API calls.
            has_text = bool(extracted_text and extracted_text.strip())
            return {
                "document_type": self.classify_document(extracted_text) if has_text else "Unknown",
                "extracted_text": extracted_text,
                "structured_data": self.extract_structured_data(extracted_text) if has_text else None,
            }
        except Exception as e:
            logger.error(f"Document processing error: {str(e)}")
            raise

    @staticmethod
    def encode_image(image: "Image.Image") -> str:
        """Return *image* as a base64-encoded JPEG (quality 95) string."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG", quality=95)
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def extract_text_from_pdf(uploaded_file) -> str:
        """Return the concatenated text layer of every page of the uploaded PDF.

        Raises:
            PDFProcessingError: if the PDF cannot be opened or read.
        """
        try:
            # The stream may already have been read on an earlier run; start from the top.
            if hasattr(uploaded_file, "seek"):
                uploaded_file.seek(0)
            pdf_bytes = uploaded_file.read()
            # The context manager closes the document (the original leaked it).
            with fitz.open(stream=pdf_bytes, filetype="pdf") as pdf_document:
                return "".join(page.get_text() for page in pdf_document)
        except Exception as e:
            logger.error(f"PDF processing error: {str(e)}")
            raise PDFProcessingError(f"Failed to process PDF: {str(e)}") from e

    @staticmethod
    def _response_text(response: Dict[str, Any]) -> str:
        """Return the stripped text of the first candidate in a ``generateContent`` response.

        Raises:
            APIError: when the response carries no candidate text (e.g. a blocked
                prompt), instead of a bare KeyError/IndexError.
        """
        try:
            return response["candidates"][0]["content"]["parts"][0]["text"].strip()
        except (KeyError, IndexError, TypeError, AttributeError) as exc:
            raise APIError(f"Gemini response contained no text: {str(response)[:300]}") from exc

    def classify_document(self, text: str) -> str:
        """Ask the model which document category *text* belongs to; returns its answer."""
        prompt = f"""
Analyze this medical document and classify it into one of the following categories:
- Lab Report
- Patient Chart
- Prescription
- Imaging Report
- Medical Certificate
- Other (specify)
Provide only the category name.
Document Text:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        return self._response_text(response)

    def extract_text(self, image_base64: str) -> str:
        """Ask the model to transcribe all visible text from a base64 JPEG document image."""
        prompt = """
Extract all visible text from this medical document.
Include:
- Headers and titles
- Patient information
- Medical data and values
- Notes and annotations
- Dates and timestamps
Format the output in a clear, structured manner.
"""
        response = GeminiAPI.call_api(prompt, image_base64)
        return self._response_text(response)

    def extract_structured_data(self, text: str) -> Dict[str, Any]:
        """Extract patient info, symptoms, visits, vital signs and medications from *text*.

        The result always has a ``patient_info`` dict and list-valued ``symptoms``,
        ``visits``, ``vital_signs`` and ``medications`` (non-object list entries
        are dropped), whatever shape the model returned.
        """
        prompt = f"""
Analyze this medical text and return a valid JSON object with the following structure:
{{
"patient_info": {{
"name": "string",
"age": "string",
"gender": "string"
}},
"symptoms": ["string"],
"visits": [
{{
"date": "string",
"reason": "string",
"notes": "string"
}}
],
"vital_signs": [
{{
"type": "string",
"value": "string"
}}
],
"medications": [
{{
"name": "string",
"dosage": "string",
"instructions": "string"
}}
]
}}
Text to analyze:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        structured_data = self.parse_json_response(response)

        # The model may omit sections or return null/odd types for them.
        if not isinstance(structured_data.get('patient_info'), dict):
            structured_data['patient_info'] = {}
        for section in ('visits', 'vital_signs', 'medications'):
            items = structured_data.get(section)
            structured_data[section] = (
                [item for item in items if isinstance(item, dict)] if isinstance(items, list) else []
            )
        symptoms = structured_data.get('symptoms')
        structured_data['symptoms'] = (
            [str(symptom) for symptom in symptoms if symptom] if isinstance(symptoms, list) else []
        )
        patient_info = structured_data['patient_info']

        # Predict gender if not mentioned. Only possible with a name, and labelled so
        # a guess is never presented as a value read from the document.
        name = str(patient_info.get('name') or '').strip()
        if not patient_info.get('gender') and name:
            predicted = self.predict_gender(name)
            if predicted:
                patient_info['gender'] = f"{predicted} (predicted from name)"

        # Correct medicine names
        structured_data['medications'] = [
            self.correct_medicine_name(med) for med in structured_data['medications']
        ]

        # Improve symptoms extraction; keep the JSON answer's list if this finds none.
        extracted_symptoms = self.extract_symptoms(text)
        if extracted_symptoms:
            structured_data['symptoms'] = extracted_symptoms
        return structured_data

    @staticmethod
    def predict_gender(name: str) -> str:
        """Predict gender based on the patient's name (the model's raw answer)."""
        prompt = f"""
Based on the name '{name}', predict the gender. Return only 'Male' or 'Female'.
"""
        response = GeminiAPI.call_api(prompt)
        return DocumentProcessor._response_text(response)

    @staticmethod
    def correct_medicine_name(medication: Dict[str, Any]) -> Dict[str, Any]:
        """Replace ``medication['name']`` in place with the model's standardised name.

        Entries without a name are returned unchanged (no API call).
        """
        name = medication.get('name')
        if not name:
            return medication
        prompt = f"""
Correct the following medicine name to its standard form:
{name}
Return only the corrected name.
"""
        response = GeminiAPI.call_api(prompt)
        medication['name'] = DocumentProcessor._response_text(response)
        return medication

    @staticmethod
    def _clean_list_lines(raw: str) -> list[str]:
        """Split a list-style model answer into items, dropping blank lines and list markers."""
        items = []
        for line in raw.splitlines():
            item = DocumentProcessor._LIST_MARKER.sub("", line, count=1).strip()
            if item:
                items.append(item)
        return items

    @staticmethod
    def extract_symptoms(text: str) -> list[str]:
        """Extract symptoms from the text, one entry per symptom without bullet markers."""
        prompt = f"""
Extract all symptoms mentioned in the following medical text. Return only a list of symptoms:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        return DocumentProcessor._clean_list_lines(DocumentProcessor._response_text(response))

    @staticmethod
    def parse_json_response(response: Dict[str, Any]) -> Dict[str, Any]:
        """Return the first JSON object embedded in the model's answer.

        Text before the first ``{`` (e.g. a code fence) and anything after the
        object's closing brace are ignored.

        Raises:
            ValueError: if no ``{`` is present or the object is not valid JSON.
            APIError: if the response carries no text.
        """
        try:
            response_text = DocumentProcessor._response_text(response)
            start = response_text.find("{")
            if start == -1:
                raise ValueError("No JSON object found in response")
            # raw_decode stops at the end of the first complete value, unlike the
            # previous greedy regex which also swallowed later braces.
            data, _ = json.JSONDecoder().raw_decode(response_text[start:])
            return data
        except Exception as e:
            logger.error(f"JSON parsing error: {str(e)}")
            raise
class EHRViewer:
    """Render structured medical data as an EHR-style Streamlit view."""

    @staticmethod
    def _display_value(value: Any) -> str:
        """Return *value* as text for ``st.metric``, with 'N/A' for missing or empty values."""
        return 'N/A' if value is None or value == "" else str(value)

    @staticmethod
    def display_ehr(data: Dict[str, Any]):
        """Show patient information, symptoms, vital signs and medications from *data*.

        Sections with missing or empty data are skipped; patient fields that are
        missing, null or empty are shown as 'N/A'.
        """
        data = data or {}
        st.markdown("## 📊 Electronic Health Record")
        with st.container():
            st.markdown("### 👤 Patient Information")
            cols = st.columns(3)
            patient_info = data.get('patient_info')
            # The model can return null (or a non-object) for patient_info.
            if not isinstance(patient_info, dict):
                patient_info = {}
            cols[0].metric("Name", EHRViewer._display_value(patient_info.get('name')))
            cols[1].metric("Age", EHRViewer._display_value(patient_info.get('age')))
            cols[2].metric("Gender", EHRViewer._display_value(patient_info.get('gender')))

        symptoms = data.get('symptoms')
        if isinstance(symptoms, str):
            # A bare string would otherwise be listed one character per bullet.
            symptoms = [symptoms]
        if symptoms:
            st.markdown("### 🤒 Symptoms")
            st.markdown("\n".join(f"- {symptom}" for symptom in symptoms))

        if data.get('vital_signs'):
            st.markdown("### 🫀 Vital Signs")
            vital_signs_df = pd.DataFrame(data['vital_signs'])
            st.dataframe(vital_signs_df, use_container_width=True)

        if data.get('medications'):
            st.markdown("### 💊 Medications")
            med_df = pd.DataFrame(data['medications'])
            st.dataframe(med_df, use_container_width=True)
class GeminiAPI:
    """Minimal client for the Gemini ``generateContent`` REST endpoint, with retries."""

    # Rate limiting, request timeouts and transient server errors are worth retrying;
    # other HTTP errors (bad request, invalid key, ...) will not succeed on retry.
    RETRYABLE_STATUS_CODES = frozenset({408, 429, 500, 502, 503, 504})

    @staticmethod
    def call_api(prompt: str, image_base64: Optional[str] = None) -> Dict[str, Any]:
        """Send *prompt* (and optionally a JPEG image) to Gemini and return the decoded JSON.

        Args:
            prompt: Text part of the request.
            image_base64: Optional base64-encoded JPEG sent as ``inline_data``.

        Returns:
            The decoded JSON body of the ``generateContent`` response.

        Raises:
            APIError: if no API key is configured, on a non-retryable HTTP error,
                or after ``Config.MAX_RETRIES`` failed attempts.
        """
        if not Config.GEMINI_API_KEY:
            raise APIError("GEMINI_API_KEY is not set; add it to the environment or the .env file.")

        parts = [{"text": prompt}]
        if image_base64:
            parts.append({
                "inline_data": {
                    "mime_type": "image/jpeg",
                    "data": image_base64
                }
            })
        payload = {"contents": [{"parts": parts}]}
        # Send the key as a header rather than a ``?key=`` query parameter: requests
        # puts the full URL into HTTPError messages, which are logged and surfaced
        # to the user through APIError.
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": Config.GEMINI_API_KEY,
        }

        last_error: Optional[Exception] = None
        for attempt in range(Config.MAX_RETRIES):
            try:
                response = requests.post(
                    Config.GEMINI_URL,
                    headers=headers,
                    json=payload,
                    timeout=Config.TIMEOUT
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.HTTPError as e:
                status = e.response.status_code if e.response is not None else None
                if status not in GeminiAPI.RETRYABLE_STATUS_CODES:
                    logger.error(f"API call failed with non-retryable HTTP status {status}: {str(e)}")
                    raise APIError(f"API call failed: {str(e)}") from e
                last_error = e
            except requests.exceptions.RequestException as e:
                last_error = e
            # Exponential backoff (1s, 2s, ...) between attempts, none after the last.
            if attempt < Config.MAX_RETRIES - 1:
                time.sleep(2 ** attempt)

        logger.error(f"API call failed after {Config.MAX_RETRIES} attempts: {str(last_error)}")
        raise APIError(f"API call failed: {str(last_error)}") from last_error
def _render_sidebar_history() -> None:
    """List every generated PDF report in the sidebar with a download button."""
    with st.sidebar:
        st.header("📋 Processing History")
        if not st.session_state.pdf_history:
            st.info("No documents processed yet")
            return
        for idx, pdf_record in enumerate(st.session_state.pdf_history):
            with st.expander(f"Document {idx + 1}: {pdf_record['timestamp']}"):
                st.download_button(
                    "📄 Download PDF",
                    pdf_record['data'],
                    file_name=pdf_record['filename'],
                    mime="application/pdf",
                    # Timestamps have one-second resolution; the index keeps keys unique.
                    key=f"sidebar_{idx}_{pdf_record['timestamp']}"
                )


def _render_pdf_history() -> None:
    """List every generated PDF report in the main area with a download button."""
    st.markdown("### 📚 PDF History")
    if not st.session_state.pdf_history:
        st.info("No PDF history available")
        return
    for idx, pdf_record in enumerate(st.session_state.pdf_history):
        label_col, button_col = st.columns([3, 1])
        with label_col:
            st.write(f"Report from {pdf_record['timestamp']}")
        with button_col:
            st.download_button(
                "📄 View PDF",
                pdf_record['data'],
                file_name=pdf_record['filename'],
                mime="application/pdf",
                key=f"history_{idx}_{pdf_record['timestamp']}"
            )


def _process_and_store(uploaded_file) -> Dict[str, Any]:
    """Run the processing pipeline on *uploaded_file* and record the outcome in session state.

    Returns the stored record: ``timestamp``, ``results``, ``source`` (file name
    and size, used to match the record to the current upload), ``pdf_filename``
    and ``pdf_bytes`` (``None`` when no structured data was extracted).
    """
    processor = DocumentProcessor()
    results = processor.process_document(uploaded_file)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    pdf_filename = f"medical_report_{timestamp}.pdf"

    pdf_bytes = None
    # Without structured data there is nothing to put in a report.
    if results['structured_data']:
        pdf_bytes = PDFGenerator.create_pdf(results['structured_data'])
        st.session_state.pdf_history.append({
            'timestamp': timestamp,
            'filename': pdf_filename,
            'data': pdf_bytes
        })

    record = {
        'timestamp': timestamp,
        'results': results,
        'source': (uploaded_file.name, uploaded_file.size),
        'pdf_filename': pdf_filename,
        'pdf_bytes': pdf_bytes,
    }
    st.session_state.current_document = record
    st.session_state.processing_history.append(record)
    return record


def _render_results(record: Dict[str, Any]) -> None:
    """Show extracted text, the EHR view, download buttons and the PDF history for *record*."""
    results = record['results']
    structured_data = results['structured_data']

    st.success("Document processed successfully!")
    st.markdown(f"**Document Type:** {results['document_type']}")
    with st.expander("View Extracted Text"):
        st.text_area(
            "Raw Text",
            results['extracted_text'],
            height=200
        )

    if structured_data:
        EHRViewer.display_ehr(structured_data)

        st.markdown("### 📥 Download Options")
        json_col, pdf_col = st.columns(2)
        with json_col:
            st.download_button(
                "⬇️ Download JSON",
                json.dumps(structured_data, indent=2),
                file_name=f"medical_data_{record['timestamp']}.json",
                mime="application/json"
            )
        if record['pdf_bytes'] is not None:
            with pdf_col:
                st.download_button(
                    "📄 Download PDF Report",
                    record['pdf_bytes'],
                    file_name=record['pdf_filename'],
                    mime="application/pdf"
                )
    else:
        st.warning("No text could be extracted from this document, so no structured record or PDF report was generated.")

    _render_pdf_history()


def _handle_upload(uploaded_file) -> None:
    """Validate and preview *uploaded_file*, process it on request and show its latest results."""
    is_valid, message = ImageProcessor.validate_file(uploaded_file)
    if not is_valid:
        st.error(message)
        return

    results_col = None
    if uploaded_file.type.startswith("image/"):
        image = Image.open(uploaded_file)
        preview_col, results_col = st.columns([1, 2])
        with preview_col:
            st.image(image, caption="Uploaded Document", use_column_width=True)
    elif uploaded_file.type == "application/pdf":
        st.info("PDF file uploaded. Processing...")

    if st.button("🔍 Process Document"):
        with st.spinner("Processing document..."):
            _process_and_store(uploaded_file)

    # Results are rendered from session state so they survive the rerun that
    # every widget interaction (e.g. a download click) triggers.
    record = st.session_state.current_document
    if not record or record.get('source') != (uploaded_file.name, uploaded_file.size):
        return
    # PDFs have no preview column. The bare ``st`` module is not a context
    # manager (the old ``with ... else st`` crashed), so use a container.
    with results_col if results_col is not None else st.container():
        _render_results(record)


def main():
    """Streamlit entry point: upload, validate, process and display a medical document."""
    # Streamlit requires set_page_config (called in setup_page) before other Streamlit commands.
    setup_page()
    init_session_state()

    st.title("🏥 Advanced Medical Document Processor")
    st.markdown("Upload medical documents (images or PDFs) for automated processing and analysis.")

    try:
        uploaded_file = st.file_uploader(
            "Choose a medical document",
            type=['png', 'jpg', 'jpeg', 'pdf'],
            help="Upload a clear image or PDF of a medical document (max 5MB)"
        )
        if uploaded_file:
            _handle_upload(uploaded_file)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        logger.exception("Error in main processing loop")
    finally:
        # Rendered last so a document processed during this run is already listed.
        _render_sidebar_history()
if __name__ == "__main__":
    # Outermost safety net: show a generic message to the user while the
    # full traceback goes to the log.
    try:
        main()
    except Exception:
        st.error("An unexpected error occurred. Please try again later.")
        logger.exception("Unhandled exception in main application")