# MedDocDigitizer / app.py
# Commit 5a0faa3 ("Update app.py", verified) by rajsecrets0 - 21.9 kB.
import base64
import io
import json
import logging
import os
import re
import time
from datetime import datetime
from io import BytesIO
from typing import Dict, Any, Optional
from xml.sax.saxutils import escape as xml_escape

import pandas as pd
import pytesseract # Tesseract OCR
import requests
import streamlit as st
from dotenv import load_dotenv # For .env file
from PIL import Image
from reportlab.lib import colors
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle
# Pull settings such as GEMINI_API_KEY from a local .env file into os.environ
# before Config reads them.
load_dotenv()

# Application-wide logging: INFO and above, timestamped, with the logger name.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration and Constants
class Config:
    """Static application settings; secrets and the model name come from the environment."""

    # Which Gemini model to call. Overridable via the GEMINI_MODEL env var so a
    # different model can be selected without a code change; the default is
    # the model this app was written against.
    GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp")
    GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent"
    GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")  # Load from .env (None if unset)
    MAX_RETRIES = 3  # attempts per Gemini call
    TIMEOUT = 30  # per-request timeout passed to requests, in seconds
    MAX_IMAGE_SIZE = (1600, 1600)  # (width, height) bound; larger images are downscaled
    ALLOWED_MIME_TYPES = ["image/jpeg", "image/png"]
    MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB
# Custom Exceptions
class APIError(Exception):
    """Raised when a call to the Gemini API fails."""


class ImageProcessingError(Exception):
    """Raised when an uploaded image cannot be preprocessed."""
# Initialize session state
def init_session_state():
    """Create the per-session keys this app relies on, leaving existing values alone."""
    # Fresh list objects are built on every call, so sessions never share state.
    defaults = {
        'processing_history': [],
        'current_document': None,
        'pdf_history': [],
    }
    for key, initial in defaults.items():
        if key in st.session_state:
            continue
        st.session_state[key] = initial
# Page setup and styling
def setup_page():
    """Apply the Streamlit page configuration and inject the app's custom CSS."""
    page_options = {
        "page_title": "Medical Document Processor",
        "page_icon": "🏥",
        "layout": "wide",
        "initial_sidebar_state": "expanded",
    }
    st.set_page_config(**page_options)

    custom_css = """
<style>
.main {padding: 2rem; max-width: 1200px; margin: 0 auto;}
.stCard {
background-color: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
margin: 1rem 0;
}
.header-container {
background-color: #f8f9fa;
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
}
.stButton>button {
background-color: #007bff;
color: white;
border: none;
padding: 0.5rem 1rem;
border-radius: 5px;
transition: all 0.3s ease;
}
.stButton>button:hover {
background-color: #0056b3;
transform: translateY(-2px);
}
.element-container {opacity: 1 !important;}
.pdf-history-item {
background-color: #f8f9fa;
padding: 1rem;
border-radius: 8px;
margin: 0.5rem 0;
border: 1px solid #dee2e6;
}
.metric-card {
background-color: #f8f9fa;
padding: 1rem;
border-radius: 8px;
border: 1px solid #dee2e6;
margin: 0.5rem 0;
}
</style>
"""
    # The raw <style> tag only takes effect with unsafe_allow_html enabled.
    st.markdown(custom_css, unsafe_allow_html=True)
class PDFGenerator:
    """Render structured medical document data as a PDF report."""

    @staticmethod
    def _markup(value: Any, default: str = 'N/A') -> str:
        """Return *value* as text that is safe inside a reportlab Paragraph.

        Paragraph text is parsed as XML-like markup, so raw ``&``, ``<`` and
        ``>`` (plausible in OCR/LLM output, e.g. "BP < 120") would break the
        build. ``None`` and empty strings are replaced by *default*.
        """
        if value is None or value == '':
            return default
        return xml_escape(str(value))

    @staticmethod
    def _grid_table(rows: list, col_widths: list) -> Table:
        """Build a Table with the report's standard grid lines and padding."""
        table = Table(rows, colWidths=col_widths)
        table.setStyle(TableStyle([
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 6),
            # Wrapped cells grow vertically; keep their neighbours top-aligned.
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
        ]))
        return table

    @staticmethod
    def create_pdf(data: Optional[Dict[str, Any]]) -> bytes:
        """Build a PDF report from structured document data.

        Args:
            data: Dict with optional ``patient_info`` (dict), ``symptoms``
                (list), ``vital_signs`` and ``medications`` (lists of dicts).
                ``None``, missing sections and missing fields are tolerated:
                patient fields fall back to ``N/A`` and empty sections are
                omitted.

        Returns:
            The rendered PDF document as bytes.
        """
        data = data or {}
        buffer = io.BytesIO()
        doc = SimpleDocTemplate(buffer, pagesize=letter)
        styles = getSampleStyleSheet()
        body_style = styles['BodyText']

        def cell(value: Any) -> Paragraph:
            # Paragraph cells wrap to the column width instead of overflowing it.
            return Paragraph(PDFGenerator._markup(value), body_style)

        elements = []

        # Title
        title_style = ParagraphStyle(
            'CustomTitle',
            parent=styles['Heading1'],
            fontSize=24,
            spaceAfter=30
        )
        elements.append(Paragraph("Medical Document Report", title_style))
        elements.append(Spacer(1, 20))

        # Patient Information (always rendered; missing fields show N/A)
        elements.append(Paragraph("Patient Information", styles['Heading2']))
        patient_info = data.get('patient_info')
        if not isinstance(patient_info, dict):
            patient_info = {}
        patient_rows = [
            ["Name:", cell(patient_info.get('name'))],
            ["Age:", cell(patient_info.get('age'))],
            ["Gender:", cell(patient_info.get('gender'))],
        ]
        elements.append(PDFGenerator._grid_table(patient_rows, [100, 300]))
        elements.append(Spacer(1, 20))

        # Symptoms
        symptoms = [symptom for symptom in (data.get('symptoms') or []) if symptom]
        if symptoms:
            elements.append(Paragraph("Symptoms", styles['Heading2']))
            # A Paragraph treats "\n" as plain whitespace; <br/> puts one item per line.
            symptoms_markup = "<br/>".join(
                f"- {PDFGenerator._markup(symptom)}" for symptom in symptoms
            )
            elements.append(Paragraph(symptoms_markup, body_style))
            elements.append(Spacer(1, 20))

        # Vital Signs
        vital_signs = [vs for vs in (data.get('vital_signs') or []) if isinstance(vs, dict)]
        if vital_signs:
            elements.append(Paragraph("Vital Signs", styles['Heading2']))
            vital_rows = [["Type", "Value"]] + [
                [cell(vs.get('type')), cell(vs.get('value'))] for vs in vital_signs
            ]
            elements.append(PDFGenerator._grid_table(vital_rows, [200, 200]))
            elements.append(Spacer(1, 20))

        # Medications
        medications = [med for med in (data.get('medications') or []) if isinstance(med, dict)]
        if medications:
            elements.append(Paragraph("Medications", styles['Heading2']))
            med_rows = [["Name", "Dosage", "Instructions"]] + [
                [cell(med.get('name')), cell(med.get('dosage')), cell(med.get('instructions'))]
                for med in medications
            ]
            elements.append(PDFGenerator._grid_table(med_rows, [150, 100, 250]))
            elements.append(Spacer(1, 20))

        doc.build(elements)
        return buffer.getvalue()
class ImageProcessor:
    """Validate uploaded document images and prepare them for OCR."""

    @staticmethod
    def validate_image(uploaded_file) -> tuple[bool, str]:
        """Check an upload against the size limit and the JPEG/PNG whitelist.

        Args:
            uploaded_file: Uploaded file object; must expose ``size`` and be
                readable by ``PIL.Image.open``.

        Returns:
            ``(True, message)`` if the file is acceptable, otherwise
            ``(False, reason)`` with a user-facing explanation. Errors are
            logged and reported through the tuple rather than raised.
        """
        try:
            if uploaded_file.size > Config.MAX_FILE_SIZE:
                return False, f"File size exceeds {Config.MAX_FILE_SIZE // (1024*1024)}MB limit"
            with Image.open(uploaded_file) as image:
                # format can be None for images PIL did not identify from a
                # file; treat that as unsupported instead of crashing on .upper().
                if (image.format or '').upper() not in ['JPEG', 'PNG']:
                    return False, "Unsupported image format. Please upload JPEG or PNG"
                # Catch corrupt/truncated files here rather than mid-processing.
                image.verify()
            return True, "Image validation successful"
        except Exception as e:
            logger.error(f"Image validation error: {str(e)}")
            return False, f"Image validation failed: {str(e)}"

    @staticmethod
    def preprocess_image(image: Image.Image) -> Image.Image:
        """Return an RGB copy of *image* that fits within ``Config.MAX_IMAGE_SIZE``.

        The caller's image is never modified. Transparent areas are flattened
        onto white: a plain ``convert('RGB')`` just drops the alpha channel,
        leaving whatever colour sits under transparent pixels (often black),
        which can hide dark text from OCR.

        Raises:
            ImageProcessingError: if conversion or resizing fails.
        """
        try:
            has_alpha = image.mode in ('RGBA', 'LA', 'PA') or (
                image.mode == 'P' and 'transparency' in image.info
            )
            if has_alpha:
                rgba = image.convert('RGBA')
                result = Image.new('RGB', rgba.size, (255, 255, 255))
                result.paste(rgba, mask=rgba.getchannel('A'))
            elif image.mode != 'RGB':
                result = image.convert('RGB')
            else:
                # thumbnail() resizes in place; copy so the caller's image is untouched.
                result = image.copy()
            if result.size[0] > Config.MAX_IMAGE_SIZE[0] or result.size[1] > Config.MAX_IMAGE_SIZE[1]:
                # thumbnail() keeps the aspect ratio while fitting inside the bound.
                result.thumbnail(Config.MAX_IMAGE_SIZE, Image.Resampling.LANCZOS)
            return result
        except Exception as e:
            logger.error(f"Image preprocessing error: {str(e)}")
            raise ImageProcessingError(f"Failed to preprocess image: {str(e)}") from e
class DocumentProcessor:
    """Extract text, a document type and structured data from a document image.

    Text comes from Tesseract (local) and Gemini (remote). Classification,
    structuring and the post-processing steps are delegated to Gemini through
    ``GeminiAPI.call_api``.
    """

    # Leading list markers ("- ", "* ", "• ", "1. ", "2) ") a model may put on
    # list items. Removed because the EHR view and PDF add their own "- " bullet.
    _LIST_MARKER = re.compile(r'^\s*(?:[-*•]\s*|\d+[.)]\s+)')

    def __init__(self):
        self.image_processor = ImageProcessor()

    def process_document(self, image: Image.Image) -> Dict[str, Any]:
        """Run the full pipeline on *image*.

        Returns:
            A dict with ``document_type`` (str), ``extracted_text`` (combined
            OCR text) and ``structured_data`` (parsed dict, or ``None`` when
            neither OCR source produced any text).

        Raises:
            Exceptions from preprocessing, the Gemini calls and JSON parsing
            are logged and re-raised unchanged.
        """
        try:
            processed_image = self.image_processor.preprocess_image(image)
            tesseract_text = self._run_tesseract(processed_image)
            image_base64 = self.encode_image(processed_image)
            gemini_text = self.extract_text_with_gemini(image_base64)
            combined_text = self.combine_text_results(tesseract_text, gemini_text)
            results = {
                "document_type": self.classify_document(combined_text),
                "extracted_text": combined_text,
                "structured_data": None
            }
            # combined_text always contains section headers, so check the raw
            # OCR outputs to decide whether there is anything to structure.
            if tesseract_text.strip() or gemini_text.strip():
                results["structured_data"] = self.extract_structured_data(combined_text)
            return results
        except Exception as e:
            logger.error(f"Document processing error: {str(e)}")
            raise

    @staticmethod
    def _run_tesseract(image: Image.Image) -> str:
        """Run Tesseract OCR on *image*, returning '' if Tesseract is missing or fails.

        Gemini's output is the primary text source and Tesseract only
        supplements it, so a missing or broken Tesseract install should not
        abort the whole run.
        """
        try:
            return pytesseract.image_to_string(image)
        except (pytesseract.TesseractNotFoundError, pytesseract.TesseractError) as e:
            logger.warning(f"Tesseract OCR unavailable, continuing with Gemini only: {e}")
            return ""

    @staticmethod
    def _response_text(response: Dict[str, Any]) -> str:
        """Return the stripped text of the first candidate in a Gemini response.

        Raises:
            APIError: if the response lacks candidates/content/parts/text, for
                instance when Gemini returns no content for the request.
        """
        try:
            return response["candidates"][0]["content"]["parts"][0]["text"].strip()
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            raise APIError(f"Unexpected Gemini response format: {e!r}") from e

    @staticmethod
    def encode_image(image: Image.Image) -> str:
        """Encode *image* as a base64 JPEG string (quality 95) for inline upload."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG", quality=95)
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def extract_text_with_gemini(image_base64: str) -> str:
        """Ask Gemini to transcribe all visible text from the base64 JPEG image."""
        prompt = """
Extract all visible text from this medical document.
Include:
- Headers and titles
- Patient information
- Medical data and values
- Notes and annotations
- Dates and timestamps
Format the output in a clear, structured manner.
"""
        response = GeminiAPI.call_api(prompt, image_base64)
        return DocumentProcessor._response_text(response)

    @staticmethod
    def combine_text_results(tesseract_text: str, gemini_text: str) -> str:
        """Concatenate both OCR outputs under labelled headers, Gemini first."""
        # Combine results, prioritizing Gemini's output but adding Tesseract's output for completeness
        combined_text = f"Gemini Extracted Text:\n{gemini_text}\n\nTesseract Extracted Text:\n{tesseract_text}"
        return combined_text

    def classify_document(self, text: str) -> str:
        """Ask Gemini for the document category; returns the model's answer verbatim."""
        prompt = f"""
Analyze this medical document and classify it into one of the following categories:
- Lab Report
- Patient Chart
- Prescription
- Imaging Report
- Medical Certificate
- Other (specify)
Provide only the category name.
Document Text:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        return self._response_text(response)

    def extract_structured_data(self, text: str) -> Dict[str, Any]:
        """Turn OCR text into the structured dict used by the EHR view and exports.

        Missing or malformed ``patient_info``/``medications`` sections are
        normalised, gender is predicted only when absent but a name exists,
        medicine names are standardised, and symptoms are re-extracted with a
        dedicated prompt.

        Raises:
            ValueError: if no JSON object can be parsed from the reply.
            APIError: if a Gemini call fails or returns no text.
        """
        prompt = f"""
Analyze this medical text and return a valid JSON object with the following structure:
{{
"patient_info": {{
"name": "string",
"age": "string",
"gender": "string"
}},
"symptoms": ["string"],
"visits": [
{{
"date": "string",
"reason": "string",
"notes": "string"
}}
],
"vital_signs": [
{{
"type": "string",
"value": "string"
}}
],
"medications": [
{{
"name": "string",
"dosage": "string",
"instructions": "string"
}}
]
}}
Text to analyze:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        structured_data = self.parse_json_response(response)

        # The model may omit or null out sections; normalise the ones
        # post-processed below so they can be indexed safely.
        patient_info = structured_data.get('patient_info')
        if not isinstance(patient_info, dict):
            patient_info = {}
        structured_data['patient_info'] = patient_info

        # Predict gender only when missing and there is a name to base it on.
        name = str(patient_info.get('name') or '').strip()
        if not patient_info.get('gender') and name:
            patient_info['gender'] = self.predict_gender(name)

        # Correct medicine names
        structured_data['medications'] = [
            self.correct_medicine_name(self._normalize_medication(med))
            for med in (structured_data.get('medications') or [])
            if med
        ]

        # Symptoms come from a dedicated prompt, replacing the JSON's list.
        structured_data['symptoms'] = self.extract_symptoms(text)
        return structured_data

    @staticmethod
    def _normalize_medication(medication: Any) -> Dict[str, Any]:
        """Coerce one medication entry into a dict with name/dosage/instructions keys."""
        if not isinstance(medication, dict):
            medication = {'name': str(medication)}
        for key in ('name', 'dosage', 'instructions'):
            medication.setdefault(key, '')
        return medication

    @staticmethod
    def predict_gender(name: str) -> str:
        """Predict gender based on the patient's name."""
        prompt = f"""
Based on the name '{name}', predict the gender. Return only 'Male' or 'Female'.
"""
        response = GeminiAPI.call_api(prompt)
        return DocumentProcessor._response_text(response)

    @staticmethod
    def correct_medicine_name(medication: Dict[str, Any]) -> Dict[str, Any]:
        """Replace ``medication['name']`` in place with Gemini's standardised form.

        Entries without a name are returned untouched, and an empty model
        reply keeps the original name.
        """
        original_name = str(medication.get('name') or '').strip()
        if not original_name:
            return medication
        prompt = f"""
Correct the following medicine name to its standard form:
{original_name}
Return only the corrected name.
"""
        response = GeminiAPI.call_api(prompt)
        corrected = DocumentProcessor._response_text(response)
        if corrected:
            medication['name'] = corrected
        return medication

    @staticmethod
    def extract_symptoms(text: str) -> list[str]:
        """Extract symptoms from the text, one per non-empty reply line, bullets removed."""
        prompt = f"""
Extract all symptoms mentioned in the following medical text. Return only a list of symptoms:
{text}
"""
        response = GeminiAPI.call_api(prompt)
        reply = DocumentProcessor._response_text(response)
        symptoms = []
        for line in reply.split("\n"):
            cleaned = DocumentProcessor._LIST_MARKER.sub('', line).strip()
            if cleaned:
                symptoms.append(cleaned)
        return symptoms

    @staticmethod
    def parse_json_response(response: Dict[str, Any]) -> Dict[str, Any]:
        """Parse the outermost ``{...}`` span of the reply text as JSON.

        Raises:
            ValueError: if the reply contains no JSON object or it is invalid.
            APIError: if the response carries no text at all.
        """
        try:
            response_text = DocumentProcessor._response_text(response)
            # Greedy match spans from the first "{" to the last "}", which
            # also strips Markdown code fences around the JSON.
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
            raise ValueError("No JSON object found in response")
        except Exception as e:
            logger.error(f"JSON parsing error: {str(e)}")
            raise
class EHRViewer:
    """Streamlit rendering of structured document data as an EHR-style view."""

    @staticmethod
    def display_ehr(data: Dict[str, Any]):
        """Render patient info, symptoms, vital signs and medications.

        Args:
            data: Structured dict with optional ``patient_info``,
                ``symptoms``, ``vital_signs`` and ``medications`` entries.
                ``None`` or missing sections are skipped; missing or empty
                patient fields are shown as ``N/A``.
        """
        data = data or {}
        st.markdown("## 📊 Electronic Health Record")
        with st.container():
            st.markdown("### 👤 Patient Information")
            cols = st.columns(3)
            patient_info = data.get('patient_info')
            if not isinstance(patient_info, dict):
                patient_info = {}
            fields = [("Name", 'name'), ("Age", 'age'), ("Gender", 'gender')]
            for col, (label, key) in zip(cols, fields):
                value = patient_info.get(key)
                # str() keeps st.metric happy if the model returned a non-scalar.
                col.metric(label, 'N/A' if value in (None, '') else str(value))

            if data.get('symptoms'):
                st.markdown("### 🤒 Symptoms")
                symptoms_text = "\n".join([f"- {symptom}" for symptom in data['symptoms']])
                st.markdown(symptoms_text)

            if data.get('vital_signs'):
                st.markdown("### 🫀 Vital Signs")
                vital_signs_df = pd.DataFrame(data['vital_signs'])
                st.dataframe(vital_signs_df, use_container_width=True)

            if data.get('medications'):
                st.markdown("### 💊 Medications")
                med_df = pd.DataFrame(data['medications'])
                st.dataframe(med_df, use_container_width=True)
class GeminiAPI:
    """Minimal client for the Gemini ``generateContent`` REST endpoint."""

    # HTTP statuses worth retrying: rate limiting and transient server errors.
    # Other 4xx responses (bad request, bad key, ...) fail the same way every time.
    _RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})

    @staticmethod
    def call_api(prompt: str, image_base64: Optional[str] = None) -> Dict[str, Any]:
        """POST *prompt* (optionally with a JPEG image) to Gemini and return the JSON reply.

        Connection errors, timeouts and HTTP 429/5xx responses are retried
        with exponential backoff (1s, 2s, ...) for up to ``Config.MAX_RETRIES``
        attempts in total. Any other HTTP error fails immediately.

        Args:
            prompt: Text prompt sent as the first content part.
            image_base64: Optional base64-encoded JPEG sent as inline data.

        Returns:
            The decoded JSON response body.

        Raises:
            APIError: if no API key is configured or the request ultimately fails.
        """
        if not Config.GEMINI_API_KEY:
            raise APIError("API call failed: GEMINI_API_KEY is not configured")

        parts = [{"text": prompt}]
        if image_base64:
            parts.append({
                "inline_data": {
                    "mime_type": "image/jpeg",
                    "data": image_base64
                }
            })
        payload = {"contents": [{"parts": parts}]}
        # Authenticate via header rather than the ?key= query string: requests
        # embeds the full request URL in HTTPError messages, and those messages
        # end up in logs and user-facing errors.
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": Config.GEMINI_API_KEY,
        }

        attempts = max(1, Config.MAX_RETRIES)
        for attempt in range(attempts):
            try:
                response = requests.post(
                    Config.GEMINI_URL,
                    headers=headers,
                    json=payload,
                    timeout=Config.TIMEOUT
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                status = e.response.status_code if e.response is not None else None
                # No status means a network-level failure (connect/timeout): retry.
                retryable = status is None or status in GeminiAPI._RETRYABLE_STATUS
                if not retryable or attempt == attempts - 1:
                    logger.error(f"API call failed after {attempt + 1} attempt(s): {str(e)}")
                    raise APIError(f"API call failed: {str(e)}") from e
                time.sleep(2 ** attempt)
def _render_sidebar_history():
    """List every PDF generated in this session in the sidebar."""
    with st.sidebar:
        st.header("📋 Processing History")
        if not st.session_state.pdf_history:
            st.info("No documents processed yet")
            return
        for idx, pdf_record in enumerate(st.session_state.pdf_history):
            with st.expander(f"Document {idx + 1}: {pdf_record['timestamp']}"):
                st.download_button(
                    "📄 Download PDF",
                    pdf_record['data'],
                    file_name=pdf_record['filename'],
                    mime="application/pdf",
                    # The index keeps keys unique even when two reports share
                    # a second-resolution timestamp.
                    key=f"sidebar_{idx}_{pdf_record['timestamp']}"
                )


def _render_pdf_history():
    """Show the session's PDF history in the main area with download buttons."""
    st.markdown("### 📚 PDF History")
    if not st.session_state.pdf_history:
        st.info("No PDF history available")
        return
    for idx, pdf_record in enumerate(st.session_state.pdf_history):
        label_col, button_col = st.columns([3, 1])
        with label_col:
            st.write(f"Report from {pdf_record['timestamp']}")
        with button_col:
            st.download_button(
                "📄 View PDF",
                pdf_record['data'],
                file_name=pdf_record['filename'],
                mime="application/pdf",
                key=f"history_{idx}_{pdf_record['timestamp']}"
            )


def _process_and_record(image):
    """Process *image* and record the outcome in session state.

    Returns:
        ``(results, timestamp, pdf_filename, pdf_bytes)``. The PDF fields are
        ``None`` when no structured data was extracted, in which case no
        PDF is generated or added to the history.
    """
    processor = DocumentProcessor()
    results = processor.process_document(image)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    pdf_filename = None
    pdf_bytes = None
    if results['structured_data']:
        pdf_bytes = PDFGenerator.create_pdf(results['structured_data'])
        pdf_filename = f"medical_report_{timestamp}.pdf"

    st.session_state.current_document = {
        'timestamp': timestamp,
        'results': results
    }
    st.session_state.processing_history.append(st.session_state.current_document)
    if pdf_bytes is not None:
        st.session_state.pdf_history.append({
            'timestamp': timestamp,
            'filename': pdf_filename,
            'data': pdf_bytes
        })
    return results, timestamp, pdf_filename, pdf_bytes


def _render_results(results, timestamp, pdf_filename, pdf_bytes):
    """Show the document type, raw text, EHR view and download buttons."""
    st.success("Document processed successfully!")
    st.markdown(f"**Document Type:** {results['document_type']}")
    with st.expander("View Extracted Text"):
        st.text_area(
            "Raw Text",
            results['extracted_text'],
            height=200
        )

    structured_data = results['structured_data']
    if not structured_data:
        st.warning("No structured data could be extracted from this document.")
        return

    EHRViewer.display_ehr(structured_data)

    st.markdown("### 📥 Download Options")
    json_col, pdf_col = st.columns(2)
    with json_col:
        st.download_button(
            "⬇️ Download JSON",
            json.dumps(structured_data, indent=2),
            file_name=f"medical_data_{timestamp}.json",
            mime="application/json"
        )
    with pdf_col:
        st.download_button(
            "📄 Download PDF Report",
            pdf_bytes,
            file_name=pdf_filename,
            mime="application/pdf"
        )


def _render_upload_section():
    """Handle upload, validation, preview and on-demand processing of a document."""
    uploaded_file = st.file_uploader(
        "Choose a medical document",
        type=['png', 'jpg', 'jpeg'],
        help="Upload a clear image of a medical document (max 5MB)"
    )
    if not uploaded_file:
        return
    try:
        is_valid, message = ImageProcessor.validate_image(uploaded_file)
        if not is_valid:
            st.error(message)
            return

        image = Image.open(uploaded_file)
        image_col, results_col = st.columns([1, 2])
        with image_col:
            st.image(image, caption="Uploaded Document", use_column_width=True)

        if st.button("🔍 Process Document"):
            with st.spinner("Processing document..."):
                outcome = _process_and_record(image)
            with results_col:
                _render_results(*outcome)

        _render_pdf_history()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        logger.exception("Error in main processing loop")


def main():
    """Streamlit entry point: upload, process and export a medical document.

    Processing only runs when the user clicks "Process Document"; generated
    PDFs accumulate in session state and are offered in the main area and
    the sidebar.
    """
    # st.set_page_config (inside setup_page) must precede other Streamlit calls.
    setup_page()
    init_session_state()
    st.title("🏥 Advanced Medical Document Processor")
    st.markdown("Upload medical documents for automated processing and analysis.")

    _render_upload_section()

    # Rendered after processing so a PDF produced in this run is already
    # listed; sidebar elements appear in the sidebar regardless of call order.
    _render_sidebar_history()
if __name__ == "__main__":
    # Last-resort boundary: show a generic message in the UI for anything
    # main() let escape, and log it with its traceback.
    try:
        main()
    except Exception:
        st.error("An unexpected error occurred. Please try again later.")
        logger.exception("Unhandled exception in main application")