Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
import numpy as np | |
import streamlit as st | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
from fpdf import FPDF | |
import requests | |
import os | |
# Where the trained classifier lives on the Hugging Face Hub, and the local
# file names used for the downloaded copy and the converted cache.
model_url = (
    "https://huggingface.co/chrisaldikaraharja/BrainTumor-Model"
    "/resolve/main/brain_tumor_detection_model.h5"
)
local_h5_path = "brain_tumor_detection_model.h5"  # downloaded HDF5 weights
# Despite the "_dir" suffix this names a single file; its .keras extension is
# what selects the native Keras format when the model is saved.
saved_model_dir = "saved_brain_tumor_model.keras"
# Function to download the model if not present locally
def download_model():
    """Download the model file from ``model_url`` to ``local_h5_path`` once.

    The response is streamed into a temporary ``<local_h5_path>.part`` file,
    which is renamed into place only after the whole body has been written.
    An interrupted or failed download therefore never leaves a truncated
    file, or an HTTP error page, at ``local_h5_path``. Such a file would be
    mistaken for a finished download on every later run.

    Returns:
        True if the model file is present (already there or just downloaded).
        False if the download failed. The failure is shown with ``st.error``
        instead of crashing the app.
    """
    if os.path.exists(local_h5_path):
        st.info("Model already downloaded.")
        return True

    st.info("Downloading model...")
    partial_path = local_h5_path + ".part"
    try:
        # The timeout keeps a stalled connection from hanging the app forever.
        with requests.get(model_url, stream=True, timeout=60) as response:
            # A 404/5xx error page must never be written out as the model.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    file.write(chunk)
        # Publish the file only once it is complete.
        os.replace(partial_path, local_h5_path)
    except (requests.RequestException, OSError) as e:
        st.error(f"Error downloading model: {e}")
        try:
            os.remove(partial_path)  # best-effort cleanup of a partial file
        except OSError:
            pass
        return False

    st.success("Model downloaded successfully.")
    return True
# Function to load the downloaded model and cache a copy in the native .keras format
def convert_and_save_model(h5_path, save_dir):
    """Load the model from *h5_path* and save a copy of it to *save_dir*.

    Despite its name, *save_dir* is the path of a single ``.keras`` file.
    Failing to write that cached copy is not fatal. The model that was
    already loaded is still returned so the app keeps working, and any
    partially written file is removed so a later run does not try to load it.

    Args:
        h5_path: path of the downloaded HDF5 model.
        save_dir: destination path for the converted copy.

    Returns:
        The loaded model, or None if *h5_path* could not be loaded.
        Errors are reported in the UI.
    """
    try:
        model = tf.keras.models.load_model(h5_path)
    except Exception as e:
        st.error(f"Error loading model from '{h5_path}': {e}")
        return None

    try:
        model.save(save_dir)  # .keras extension -> native Keras format
    except Exception as e:
        st.warning(f"Error saving model: {e}")
        try:
            os.remove(save_dir)  # drop a half-written cache file, if any
        except OSError:
            pass
    else:
        st.success(f"Model saved in .keras format at '{save_dir}'.")
    return model
# Function to load the saved model
def load_saved_model(saved_model_path):
    """Return the Keras model stored at *saved_model_path*, or None on failure.

    Any exception raised while loading is shown in the UI with ``st.error``
    rather than propagated, so callers only have to check for None.
    """
    loaded = None
    try:
        loaded = tf.keras.models.load_model(saved_model_path)
    except Exception as exc:
        st.error(f"Error loading the model: {exc}")
    return loaded
# Make sure the downloaded .h5 file exists, then obtain a model. Prefer the
# cached .keras copy, and rebuild it from the .h5 file when it is missing or
# cannot be loaded.
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so this load runs on every rerun. Consider caching it (e.g. with
# st.cache_resource) if the installed Streamlit version provides it.
download_model()

if os.path.exists(saved_model_dir):
    model = load_saved_model(saved_model_dir)
    if model is None:
        # The cached copy is unreadable (e.g. left half-written by an
        # interrupted save), so regenerate it from the downloaded file.
        st.warning("Cached model could not be loaded; re-converting from the downloaded file.")
        model = convert_and_save_model(local_h5_path, saved_model_dir)
else:
    model = convert_and_save_model(local_h5_path, saved_model_dir)

if model is None:
    st.error("Failed to load the model. Please check the model file.")
    # Every prediction below needs the model; stop here instead of offering a
    # Predict button that can only fail.
    st.stop()
else:
    st.success("Model loaded successfully.")
# Maps each index of the model's output vector to its human-readable class name.
label_map = dict(enumerate(("Glioma", "Meningioma", "Normal", "Pituitary")))
# Function to preprocess image
def preprocess_image(image):
    """Convert a PIL image into a normalized batch of one image for the model.

    The image is converted to RGB first. Uploaded MRI scans are often
    grayscale ("L"), 16-bit, palette or RGBA images. Without the conversion,
    ``np.array`` would yield no channel axis at all, or four channels.
    NOTE(review): assumes the model expects 3-channel input (the shape an RGB
    upload produces). Confirm against the model's input shape.

    Args:
        image: a ``PIL.Image.Image``.

    Returns:
        float ndarray of shape (1, 150, 150, 3) with values in [0, 1].

    Raises:
        Any exception from conversion or resizing, after it has been reported
        with ``st.error``.
    """
    try:
        if image.mode != "RGB":
            image = image.convert("RGB")
        img = image.resize((150, 150))  # match the model's 150x150 input
        img = np.array(img) / 255.0  # uint8 [0, 255] -> float [0, 1]
        return np.expand_dims(img, axis=0)  # add the batch axis
    except Exception as e:
        st.error(f"Error in preprocessing image: {e}")
        raise
# Function to compute Occlusion Sensitivity
def compute_occlusion_sensitivity(model, img_array, class_idx, patch_size=15, *,
                                  batch_size=32, occlusion_value=0):
    """Measure how much each image region contributes to one class score.

    The image is tiled into ``patch_size`` x ``patch_size`` squares (tiles at
    the right/bottom edge may be smaller). For each tile, a copy of the image
    with that tile set to ``occlusion_value`` is scored. Every pixel of the
    tile then receives ``original_score - occluded_score``, so positive values
    mark regions whose removal lowers the class score.

    Occluded copies are scored in batches of up to ``batch_size`` images
    rather than with one ``model.predict`` call per tile. This avoids the
    per-call overhead while keeping at most ``batch_size`` image copies in
    memory.

    Args:
        model: object with a Keras-style ``predict(batch, verbose=...)``
            returning an array of shape (batch, n_classes).
        img_array: a single image, either with a leading batch axis of size 1
            (shape (1, H, W[, C]); the caller in this file passes the output
            of ``preprocess_image``) or without one (shape (H, W[, C])).
        class_idx: index of the class score to track.
        patch_size: side length of the occluding square, in pixels.
        batch_size: maximum number of occluded images per ``predict`` call.
        occlusion_value: value written into the occluded tile.

    Returns:
        float ndarray of shape (H, W).

    Raises:
        ValueError: if ``patch_size`` or ``batch_size`` is less than 1.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    arr = np.asarray(img_array)
    # Drop only the batch axis; squeeze() would also drop a size-1 channel axis.
    image = arr[0] if arr.ndim >= 3 and arr.shape[0] == 1 else arr
    height, width = image.shape[:2]

    baseline = model.predict(image[np.newaxis], verbose=0)[0, class_idx]

    origins = [(i, j) for i in range(0, height, patch_size)
               for j in range(0, width, patch_size)]
    sensitivity_map = np.zeros((height, width))
    for start in range(0, len(origins), batch_size):
        chunk = origins[start:start + batch_size]
        batch = np.repeat(image[np.newaxis], len(chunk), axis=0)
        for k, (i, j) in enumerate(chunk):
            # Slicing only the two spatial axes blanks every channel, and it
            # also works for 2-D (channel-less) images.
            batch[k, i:i + patch_size, j:j + patch_size] = occlusion_value
        scores = model.predict(batch, verbose=0)[:, class_idx]
        for (i, j), score in zip(chunk, scores):
            sensitivity_map[i:i + patch_size, j:j + patch_size] = baseline - score
    return sensitivity_map
# Function to visualize the occlusion map
def visualize_occlusion_map(img, occlusion_map, output_path="/tmp/occlusion_map.png"):
    """Draw *occlusion_map* as a heat map over *img* and save it as a PNG.

    The sensitivity map is computed at the model's input resolution, which
    generally differs from the resolution of the image being displayed. The
    overlay is therefore given an ``extent`` equal to the image's pixel
    extent so it stretches over the whole picture. Without it, the second
    ``imshow`` rescales the axes to the map's own (smaller) extent and only
    the top-left corner of the image is visible under the heat map.

    A standalone ``Figure`` is used instead of pyplot, so no global pyplot
    state is shared between concurrently running sessions and nothing needs
    closing.

    Args:
        img: image array of shape (H, W) or (H, W, C).
        occlusion_map: 2-D array of sensitivity values.
        output_path: destination of the PNG file.

    Returns:
        ``output_path``.
    """
    from matplotlib.figure import Figure

    img = np.asarray(img)
    height, width = img.shape[:2]

    fig = Figure(figsize=(10, 10))
    ax = fig.add_subplot()
    # A 2-D (grayscale) image would otherwise be drawn with the default colormap.
    ax.imshow(img, cmap="gray" if img.ndim == 2 else None)
    overlay = ax.imshow(
        occlusion_map,
        cmap="hot",
        alpha=0.5,
        # (left, right, bottom, top) in the image's pixel coordinates, the
        # same extent imshow uses by default for *img*.
        extent=(-0.5, width - 0.5, height - 0.5, -0.5),
    )
    ax.axis("off")
    fig.colorbar(overlay, ax=ax)
    fig.savefig(output_path)
    return output_path
# Typographic characters commonly pasted from word processors, mapped to
# Latin-1 equivalents before any remaining non-Latin-1 text is replaced.
_PDF_CHAR_FALLBACKS = str.maketrans({
    "\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"',
    "\u2013": "-", "\u2014": "-", "\u2026": "...",
})


def _to_latin1(text):
    """Return *text* restricted to Latin-1, replacing other characters with '?'.

    FPDF's built-in fonts (such as Arial) only cover Latin-1. Other characters
    would make PDF generation fail, so they are mapped or replaced first.
    None is treated as an empty string.
    """
    text = "" if text is None else str(text)
    return text.translate(_PDF_CHAR_FALLBACKS).encode("latin-1", "replace").decode("latin-1")


# Function to generate PDF report
def create_pdf_report(prediction_text, doctor_notes, image_path,
                      output_path="/tmp/medical_report.pdf"):
    """Write a one-section PDF report and return its path.

    The report contains the prediction text, a bold "Doctor's Notes:" heading,
    the notes, and the image at *image_path* scaled to the page width minus
    10 mm margins on each side.

    Args:
        prediction_text: multi-line prediction summary.
        doctor_notes: free text entered by the user.
        image_path: path of the image to embed.
        output_path: where to write the PDF.

    Returns:
        ``output_path``.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    pdf.multi_cell(0, 10, _to_latin1(prediction_text))
    pdf.ln(10)
    pdf.set_font("Arial", 'B', size=12)
    pdf.cell(0, 10, "Doctor's Notes:")
    pdf.ln(10)
    pdf.set_font("Arial", size=12)
    pdf.multi_cell(0, 10, _to_latin1(doctor_notes))
    # No explicit y. FPDF then places the image at the current position and
    # starts a new page first if it would not fit, instead of drawing past the
    # bottom edge when the notes are long.
    pdf.image(image_path, x=10, w=pdf.w - 20)
    pdf.output(output_path)
    return output_path
# Main prediction and report generation function
def predict_and_generate_report(image, doctor_notes):
    """Classify *image*, then build the PDF report and the occlusion heat map.

    Uses the module-level ``model`` and ``label_map``.

    Args:
        image: the uploaded ``PIL.Image.Image``.
        doctor_notes: free text to include in the PDF report.

    Returns:
        ``(image, prediction_text, pdf_report_path, occlusion_map_path)`` on
        success. On any failure the error is shown with ``st.error`` and
        ``(None, "Error during prediction. Please check the input image.",
        None, None)`` is returned.
    """
    # PIL modes the PNG writer accepts. Anything else (e.g. CMYK from some
    # JPEGs) is converted to RGB for the saved and plotted copy.
    png_modes = {"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"}
    try:
        img_array = preprocess_image(image)
        predictions = model.predict(img_array)[0]
        class_idx = int(np.argmax(predictions))
        class_label = label_map[class_idx]

        # One "<label>: <score as a percentage>" line per model output, in order.
        prediction_results = [
            f"{label_map[i]}: {score * 100:.2f}%" for i, score in enumerate(predictions)
        ]
        prediction_text = f"Prediction: {class_label}\n" + "\n".join(prediction_results)

        # Save a copy of the upload so it can be embedded in the PDF.
        saveable = image if image.mode in png_modes else image.convert("RGB")
        image_path = "/tmp/image.png"
        saveable.save(image_path)

        pdf_report_path = create_pdf_report(prediction_text, doctor_notes, image_path)

        occlusion_map = compute_occlusion_sensitivity(model, img_array, class_idx)
        occlusion_map_path = visualize_occlusion_map(np.array(saveable), occlusion_map)
        return image, prediction_text, pdf_report_path, occlusion_map_path
    except Exception as e:
        st.error(f"Error during prediction: {e}")
        return None, "Error during prediction. Please check the input image.", None, None
# Streamlit Interface
st.title("Brain Tumor Detection and Occlusion Sensitivity")
st.write("Upload an MRI image to detect brain tumors and view which areas contributed to the prediction.")

# Image upload and free-text notes that go into the PDF report.
uploaded_image = st.file_uploader("Upload MRI Image", type=["png", "jpg", "jpeg"])
doctor_notes = st.text_area("Enter doctor's notes here...", "")

if st.button("Predict"):
    if uploaded_image is None:
        st.warning("Please upload an MRI image to proceed.")
    else:
        try:
            image = Image.open(uploaded_image)
            # Image.open only reads the header, so decode now to make corrupt
            # or truncated files fail here rather than mid-prediction.
            image.load()
        except OSError as e:  # PIL's UnidentifiedImageError is an OSError
            st.error(f"Could not read the uploaded file as an image: {e}")
        else:
            drawn_image, prediction_text, pdf_report_path, occlusion_map_path = predict_and_generate_report(image, doctor_notes)
            if drawn_image is not None:
                st.image(drawn_image, caption="Original MRI Image", use_column_width=True)
                st.text(prediction_text)

                # Provide the PDF report for download, closing the file afterwards.
                with open(pdf_report_path, "rb") as report_file:
                    report_bytes = report_file.read()
                st.download_button("Download Medical Report", data=report_bytes, file_name="medical_report.pdf")

                # Display the occlusion sensitivity map.
                with Image.open(occlusion_map_path) as occlusion_image:
                    st.image(occlusion_image, caption="Occlusion Sensitivity Map", use_column_width=True)