Amarthya7's picture
Upload 21 files
86a74e6 verified
raw
history blame
15.7 kB
import base64
import io
import logging
import re

import cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
# Module-level logger for this visualization module, named after the module.
logger = logging.getLogger(name=__name__)
def plot_image_prediction(image, predictions, title=None, figsize=(10, 8)):
    """
    Plot an image next to a bar chart of its highest-probability predictions.

    Args:
        image (PIL.Image or str): Image or path to image. Single-channel
            (grayscale) images are drawn with a gray colormap.
        predictions (list): List of (label, probability) tuples. At most the
            five most probable entries are shown, highest at the top.
        title (str, optional): Overall figure title
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object. If plotting fails, the
        error is logged and a small figure showing the error message is
        returned instead.
    """
    fig = None
    try:
        # Load image if a path is provided; copy the pixels so the file
        # handle opened by PIL is closed right away instead of leaking.
        if isinstance(image, str):
            with Image.open(image) as src:
                img = src.copy()
        else:
            img = image

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

        # cmap only applies to single-channel data (RGB/RGBA input ignores
        # it); without it a grayscale X-ray would be drawn in matplotlib's
        # default false-colour map.
        ax1.imshow(img, cmap="gray")
        ax1.set_title("X-ray Image")
        ax1.axis("off")

        if predictions:
            # Keep the five most probable (label, probability) pairs.
            top = sorted(predictions, key=lambda x: x[1], reverse=True)[:5]
            labels = [pred[0] for pred in top]
            probs = [pred[1] for pred in top]

            y_pos = np.arange(len(top))
            ax2.barh(y_pos, probs, align="center")
            ax2.set_yticks(y_pos)
            ax2.set_yticklabels(labels)
            # Rank 0 (the most probable label) is drawn at the top.
            ax2.invert_yaxis()
            ax2.set_xlabel("Probability")
            ax2.set_title("Top Predictions")
            ax2.set_xlim(0, 1)

            # Annotate each bar with its probability as a percentage.
            for i, prob in enumerate(probs):
                ax2.text(prob + 0.02, i, f"{prob:.1%}", va="center")
        else:
            ax2.text(0.5, 0.5, "No predictions available", ha="center", va="center")
            ax2.axis("off")

        if title:
            fig.suptitle(title, fontsize=16)

        fig.tight_layout()
        return fig
    except Exception as e:
        logger.exception("Error plotting image prediction: %s", e)
        # Release the partially built figure; pyplot keeps every figure alive
        # until it is explicitly closed.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
def _heatmap_input_to_pil(image):
    """Best-effort conversion of a create_heatmap_overlay() input to a PIL image."""
    if isinstance(image, str):
        # Copy so the file handle is released immediately.
        with Image.open(image) as src:
            return src.copy()
    if isinstance(image, Image.Image):
        return image
    arr = np.asarray(image)
    # numpy input is treated as OpenCV-ordered (BGR / BGRA) data.
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]
    elif arr.ndim == 3 and arr.shape[2] == 3:
        arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    elif arr.ndim == 3 and arr.shape[2] == 4:
        arr = cv2.cvtColor(arr, cv2.COLOR_BGRA2RGBA)
    return Image.fromarray(arr)


def create_heatmap_overlay(image, heatmap, alpha=0.4):
    """
    Blend a JET-coloured heatmap over an X-ray image to highlight areas of interest.

    Args:
        image (PIL.Image, str or numpy.ndarray): PIL image (any mode), path to
            an image file, or an OpenCV-style array (2-D grayscale, 3-channel
            BGR or 4-channel BGRA). Array input is expected to be uint8 so it
            can be blended with the uint8 colour-mapped heatmap.
        heatmap (array-like): 2-D map; it is resized to the image size,
            non-finite and negative values are set to 0, and the result is
            scaled by its maximum (an all-zero map stays all zero).
        alpha (float): Weight of the heatmap in the blend
            (0 = image only, 1 = heatmap only).

    Returns:
        PIL.Image: RGB image with the heatmap overlay. If the overlay cannot
        be built, the error is logged and the original image is returned as a
        PIL image.
    """
    try:
        if isinstance(image, str):
            # imread's default flag always yields a 3-channel BGR array.
            img = cv2.imread(image)
            if img is None:
                raise ValueError(f"Could not load image: {image}")
        elif isinstance(image, Image.Image):
            # Normalise to RGB first: grayscale ("L") or palette PIL images
            # give a 1-channel array, which COLOR_RGB2BGR rejects.
            img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        else:
            img = np.asarray(image)

        # Bring the image to 3-channel BGR for blending.
        channels = 1 if img.ndim == 2 else img.shape[2]
        if channels == 1:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif channels == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

        # Resize heatmap to the image (cv2 takes (width, height)).
        heatmap = np.asarray(heatmap, dtype=np.float32)
        heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

        # Normalise to [0, 1]; guard against NaN/inf and a zero maximum, which
        # would otherwise produce NaN before the uint8 cast.
        heatmap = np.where(np.isfinite(heatmap), heatmap, 0.0)
        heatmap = np.maximum(heatmap, 0)
        peak = float(heatmap.max())
        if peak > 0:
            heatmap = heatmap / peak

        heatmap = cv2.applyColorMap((255 * heatmap).astype(np.uint8), cv2.COLORMAP_JET)
        overlay = cv2.addWeighted(img, 1 - alpha, heatmap, alpha, 0)

        # Convert back to RGB for PIL.
        return Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
    except Exception as e:
        logger.exception("Error creating heatmap overlay: %s", e)
        # Return the original image if the overlay fails.
        return _heatmap_input_to_pil(image)
def _escape_mathtext(s):
    """Escape '$' so matplotlib shows it literally instead of parsing mathtext."""
    return s.replace("$", r"\$")


def _mark_report_entities(text, entities):
    """
    Wrap every whole-word occurrence of an entity in *text* in square brackets.

    Args:
        text (str): Report text
        entities (dict): Mapping of category -> iterable of entity strings

    Returns:
        str: *text* with each match rewritten as ``[entity]``.
    """
    terms = {
        str(item)
        for items in entities.values()
        for item in (items or [])
        if str(item).strip()  # an empty alternative would match everywhere
    }
    if not terms:
        return text
    # Python's regex alternation is leftmost-first, so listing longer terms
    # first lets "pleural effusion" win over "effusion". A single pass means
    # brackets inserted for one entity are never matched again.
    # re.escape handles every metacharacter. (?<!\w)/(?!\w) give whole-word
    # matching even for entities that start or end with punctuation, where
    # \b would fail.
    ordered = sorted(terms, key=lambda t: (-len(t), t))
    alternation = "|".join(re.escape(t) for t in ordered)
    pattern = re.compile(rf"(?<!\w)(?:{alternation})(?!\w)")
    return pattern.sub(lambda m: f"[{m.group(0)}]", text)


def plot_report_entities(text, entities, figsize=(12, 8)):
    """
    Visualize entities extracted from a medical report.

    The figure lists the extracted entities per category, coloured by
    category, followed by the report text. In the report text every
    whole-word occurrence of an entity is wrapped in square brackets.
    Brackets are used because a single matplotlib text artist cannot mix
    colours, and LaTeX colour markup is not interpreted by matplotlib's
    default renderer.

    Args:
        text (str): Report text
        entities (dict): Dictionary mapping category name -> list of entities
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object. On failure, the error is
        logged and a small figure containing the error message is returned.
    """
    fig = None
    try:
        entities = entities or {}

        fig, ax = plt.subplots(figsize=figsize)
        ax.axis("off")

        # Set background color
        fig.patch.set_facecolor("#f8f9fa")
        ax.set_facecolor("#f8f9fa")

        # Title
        ax.text(
            0.5,
            0.98,
            "Medical Report Analysis",
            ha="center",
            va="top",
            fontsize=18,
            fontweight="bold",
            color="#2c3e50",
        )

        # Entity section header
        y_pos = 0.9
        ax.text(
            0.05,
            y_pos,
            "Extracted Entities:",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        category_colors = {
            "problem": "#e74c3c",  # Red
            "test": "#3498db",  # Blue
            "treatment": "#2ecc71",  # Green
            "anatomy": "#9b59b6",  # Purple
        }

        # One heading line plus one comma-separated list per non-empty category.
        for category, items in entities.items():
            if not items:
                continue
            y_pos -= 0.05
            ax.text(
                0.1,
                y_pos,
                f"{str(category).capitalize()}:",
                fontsize=12,
                fontweight="bold",
            )
            y_pos -= 0.05
            ax.text(
                0.15,
                y_pos,
                _escape_mathtext(", ".join(str(item) for item in items)),
                wrap=True,
                fontsize=11,
                color=category_colors.get(category, "black"),
            )

        # Report text with entities marked
        y_pos -= 0.1
        ax.text(
            0.05,
            y_pos,
            "Report Text (with highlighted entities):",
            fontsize=14,
            fontweight="bold",
            color="#2c3e50",
        )
        y_pos -= 0.05

        highlighted_text = _mark_report_entities(text or "", entities)
        ax.text(
            0.05,
            y_pos,
            _escape_mathtext(highlighted_text),
            va="top",
            fontsize=10,
            wrap=True,
        )

        fig.tight_layout(rect=[0, 0.03, 1, 0.97])
        return fig
    except Exception as e:
        logger.exception("Error plotting report entities: %s", e)
        # Release the partially built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
def _draw_overview_panel(ax, fused_results):
    """Draw primary finding, colour-coded severity and agreement score on *ax*."""
    ax.axis("off")

    severity = fused_results.get("severity") or {}
    severity_level = severity.get("level", "Unknown")
    severity_score = severity.get("score", 0)
    primary_finding = fused_results.get("primary_finding", "Unknown")
    # `or 0` also covers an explicit None, which the "{:.0%}" format rejects.
    agreement = fused_results.get("agreement_score") or 0

    severity_colors = {
        "Normal": "#2ecc71",  # Green
        "Mild": "#3498db",  # Blue
        "Moderate": "#f39c12",  # Orange
        "Severe": "#e74c3c",  # Red
        "Critical": "#c0392b",  # Dark Red
    }

    y_pos = 0.9
    ax.text(
        0.5,
        y_pos,
        "ANALYSIS OVERVIEW",
        fontsize=14,
        fontweight="bold",
        ha="center",
        va="center",
    )
    y_pos -= 0.15
    ax.text(
        0.1,
        y_pos,
        f"Primary Finding: {primary_finding}",
        fontsize=12,
        ha="left",
        va="center",
    )
    y_pos -= 0.1

    # Severity label, coloured level, and score out of 4.
    ax.text(0.1, y_pos, "Severity Level:", fontsize=12, ha="left", va="center")
    ax.text(
        0.4,
        y_pos,
        severity_level,
        fontsize=12,
        color=severity_colors.get(severity_level, "black"),
        fontweight="bold",
        ha="left",
        va="center",
    )
    ax.text(0.6, y_pos, f"({severity_score}/4)", fontsize=10, ha="left", va="center")
    y_pos -= 0.1

    # Agreement colour: green above 0.7, orange above 0.4, red otherwise.
    if agreement > 0.7:
        agreement_color = "#2ecc71"
    elif agreement > 0.4:
        agreement_color = "#f39c12"
    else:
        agreement_color = "#e74c3c"
    ax.text(0.1, y_pos, "Agreement Score:", fontsize=12, ha="left", va="center")
    ax.text(
        0.4,
        y_pos,
        f"{agreement:.0%}",
        fontsize=12,
        color=agreement_color,
        fontweight="bold",
        ha="left",
        va="center",
    )


def _draw_bullet_panel(ax, heading, items, empty_message, max_items=5):
    """Draw *heading* and up to *max_items* bulleted *items* on *ax*."""
    ax.axis("off")

    y_pos = 0.9
    ax.text(0.5, y_pos, heading, fontsize=14, fontweight="bold", ha="center", va="center")
    y_pos -= 0.1

    # A lone string is one entry, not a sequence of characters.
    if isinstance(items, str):
        items = [items]
    items = list(items or [])
    if not items:
        ax.text(0.1, y_pos, empty_message, fontsize=11, ha="left", va="center")
        return

    for item in items[:max_items]:
        ax.text(0.05, y_pos, "•", fontsize=14, ha="left", va="center")
        ax.text(0.1, y_pos, item, fontsize=11, ha="left", va="center", wrap=True)
        y_pos -= 0.15

    # Entries beyond max_items would fall outside the panel; report the
    # count instead of dropping them silently.
    hidden = len(items) - max_items
    if hidden > 0:
        ax.text(
            0.1,
            y_pos,
            f"... and {hidden} more",
            fontsize=10,
            style="italic",
            ha="left",
            va="center",
        )


def plot_multimodal_results(
    fused_results, image=None, report_text=None, figsize=(12, 10)
):
    """
    Visualize the results of multimodal analysis on a 2x2 grid.

    Panels: overview (top left), key findings (top right), image
    (bottom left), follow-up recommendations (bottom right). Findings and
    recommendations show at most five entries each, plus a count of any
    that are omitted.

    Args:
        fused_results (dict): Results from multimodal fusion. Keys read:
            "severity" (dict with "level" and "score"), "primary_finding",
            "agreement_score", "findings", "followup_recommendations".
        image (PIL.Image or str, optional): Image or path to image
        report_text (str, optional): Report text. Accepted for API
            compatibility; it is not currently drawn on the figure.
        figsize (tuple): Figure size

    Returns:
        matplotlib.figure.Figure: The figure object. On failure, the error is
        logged and a small figure containing the error message is returned.
    """
    fig = None
    try:
        fused_results = fused_results or {}

        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(2, 2)
        fig.suptitle(
            "Multimodal Medical Analysis Results",
            fontsize=18,
            fontweight="bold",
            y=0.98,
        )

        # 1. Overview panel (top left)
        _draw_overview_panel(fig.add_subplot(gs[0, 0]), fused_results)

        # 2. Findings panel (top right)
        _draw_bullet_panel(
            fig.add_subplot(gs[0, 1]),
            "KEY FINDINGS",
            fused_results.get("findings", []),
            "No specific findings detailed.",
        )

        # 3. Image panel (bottom left)
        ax_image = fig.add_subplot(gs[1, 0])
        if image is not None:
            # Copy so PIL's file handle is released immediately.
            if isinstance(image, str):
                with Image.open(image) as src:
                    img = src.copy()
            else:
                img = image
            # cmap only affects single-channel data; RGB(A) images ignore it.
            ax_image.imshow(img, cmap="gray")
            ax_image.set_title("X-ray Image", fontsize=12)
        else:
            ax_image.text(0.5, 0.5, "No image available", ha="center", va="center")
        ax_image.axis("off")

        # 4. Recommendation panel (bottom right)
        _draw_bullet_panel(
            fig.add_subplot(gs[1, 1]),
            "RECOMMENDATIONS",
            fused_results.get("followup_recommendations", []),
            "No specific recommendations provided.",
        )

        # Add disclaimer
        fig.text(
            0.5,
            0.03,
            "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice.",
            fontsize=9,
            style="italic",
            ha="center",
        )

        fig.tight_layout(rect=[0, 0.05, 1, 0.95])
        return fig
    except Exception as e:
        logger.exception("Error plotting multimodal results: %s", e)
        # Release the partially built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.text(0.5, 0.5, f"Error: {str(e)}", ha="center", va="center")
        return fig
def figure_to_base64(fig):
    """
    Render a matplotlib figure to PNG and return it base64-encoded.

    Args:
        fig (matplotlib.figure.Figure): Figure to serialise; it is saved with
            format="png" and bbox_inches="tight".

    Returns:
        str: Base64 text of the PNG bytes, or an empty string if saving or
        encoding fails (the error is logged).
    """
    try:
        with io.BytesIO() as stream:
            fig.savefig(stream, format="png", bbox_inches="tight")
            png_bytes = stream.getvalue()
        return base64.b64encode(png_bytes).decode("utf-8")
    except Exception as e:
        logger.error(f"Error converting figure to base64: {e}")
        return ""