# btdetection / app.py
# Author: chrisaldikaraharja
# Last commit: 6f6031c "Update app.py" (verified)
import tensorflow as tf
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
from PIL import Image
from fpdf import FPDF
import requests
import os
# Remote location of the pretrained classifier (Keras HDF5 file on Hugging Face).
model_url = (
    "https://huggingface.co/chrisaldikaraharja/BrainTumor-Model"
    "/resolve/main/brain_tumor_detection_model.h5"
)
# Where the downloaded HDF5 file is kept locally.
local_h5_path = "brain_tumor_detection_model.h5"
# Path of the re-saved model in the native Keras format (a single .keras file,
# despite the "_dir" suffix).
saved_model_dir = "saved_brain_tumor_model.keras"
def download_model(timeout=60):
    """Download the HDF5 model from ``model_url`` to ``local_h5_path`` if absent.

    The response is streamed into a temporary ``<local_h5_path>.part`` file and
    only moved into place (``os.replace``) after a successful HTTP status and a
    complete body. Because later code treats "file exists" as "model
    available", writing directly to ``local_h5_path`` would let an HTTP error
    page or an interrupted download masquerade as the model on every later run.

    Args:
        timeout: seconds ``requests`` waits to connect / between received bytes.

    Download failures are reported with ``st.error`` (and no file is left
    behind); the function returns ``None`` in every case.
    """
    if os.path.exists(local_h5_path):
        st.info("Model already downloaded.")
        return

    st.info("Downloading model...")
    tmp_path = local_h5_path + ".part"
    try:
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            # Without this a 404/5xx body would be saved as the model file.
            response.raise_for_status()
            with open(tmp_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    file.write(chunk)
        os.replace(tmp_path, local_h5_path)
    except (requests.RequestException, OSError) as e:
        # Never leave a partial file that could be mistaken for the model.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        st.error(f"Error downloading model: {e}")
        return
    st.success("Model downloaded successfully.")
def convert_and_save_model(h5_path, save_dir):
    """Load the HDF5 model at *h5_path* and re-save it to *save_dir*.

    *save_dir* is passed straight to ``model.save``; with a ``.keras`` path this
    writes the native Keras format.

    Returns:
        The loaded model, or ``None`` if *h5_path* could not be loaded. A
        failure to write *save_dir* is only reported as a warning: the model
        in memory is still usable, so it is returned anyway rather than
        discarding it.
    """
    try:
        model = tf.keras.models.load_model(h5_path)
    except Exception as e:
        st.error(f"Error loading the model from '{h5_path}': {e}")
        return None

    try:
        model.save(save_dir)
    except Exception as e:
        st.warning(f"Error saving model: {e}")
    else:
        st.success(f"Model saved in .keras format at '{save_dir}'.")
    return model
def load_saved_model(saved_model_path):
    """Load a Keras model from *saved_model_path*.

    Returns the model; if loading raises, the exception is shown with
    ``st.error`` and ``None`` is returned instead.
    """
    try:
        return tf.keras.models.load_model(saved_model_path)
    except Exception as err:
        st.error(f"Error loading the model: {err}")
    return None
# Streamlit re-executes this whole script on every widget interaction, so the
# model must not be re-read from disk on each rerun. st.cache_resource
# (Streamlit >= 1.18) keeps a single instance per server process; on older
# Streamlit versions without it, fall back to calling the loader every time.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model_once():
    """Download (if needed) and load the classifier.

    Prefers the re-saved ``.keras`` copy; if it is missing or unreadable
    (e.g. left behind by an interrupted save), it is rebuilt from the
    downloaded ``.h5`` file.

    Raises:
        RuntimeError: if no usable model could be obtained. Raising instead of
            returning ``None`` keeps st.cache_resource from caching the
            failure, so the next rerun tries again.
    """
    download_model()
    loaded = None
    if os.path.exists(saved_model_dir):
        loaded = load_saved_model(saved_model_dir)
    if loaded is None:
        loaded = convert_and_save_model(local_h5_path, saved_model_dir)
    if loaded is None:
        raise RuntimeError("Model could not be loaded.")
    return loaded


try:
    model = _load_model_once()
except RuntimeError:
    model = None

# Report whether a model is available for prediction.
if model is None:
    st.error("Failed to load the model. Please check the model file.")
else:
    st.success("Model loaded successfully.")
# Maps each index of the model's output vector to its class name.
label_map = dict(enumerate(("Glioma", "Meningioma", "Normal", "Pituitary")))
def preprocess_image(image, target_size=(150, 150), mode="RGB"):
    """Turn a PIL image into a normalized batch holding that one image.

    Args:
        image: PIL image (e.g. the uploaded MRI scan).
        target_size: PIL ``(width, height)`` to resize to; the default matches
            the size the app has always used for the model input.
        mode: PIL mode to convert to before resizing, or ``None`` to keep the
            image's own mode. Converting to "RGB" makes grayscale ("L") and
            alpha ("RGBA") uploads yield 3 channels instead of a
            ``(1, H, W)`` or ``(1, H, W, 4)`` batch. Assumes the model takes
            3-channel input — TODO confirm against ``model.input_shape``.

    Returns:
        float32 array of shape ``(1, height, width, channels)`` with values in
        [0, 1] for 8-bit modes.

    Raises:
        Any exception from PIL/NumPy, after showing it with ``st.error``.
    """
    try:
        if mode is not None and image.mode != mode:
            image = image.convert(mode)
        resized = image.resize(target_size)
        pixels = np.asarray(resized, dtype=np.float32) / 255.0
        return np.expand_dims(pixels, axis=0)  # add the batch dimension
    except Exception as e:
        st.error(f"Error in preprocessing image: {e}")
        raise
def compute_occlusion_sensitivity(model, img_array, class_idx, patch_size=15):
    """Estimate how much each image region supports the prediction *class_idx*.

    The image is tiled into ``patch_size`` x ``patch_size`` squares (tiles on
    the right/bottom edge may be smaller). For each tile, a copy of the image
    with that tile set to 0 is scored. The drop ``p(original) - p(occluded)``
    for *class_idx* is written into every pixel of the tile, so positive
    values mark regions whose removal lowers confidence in the class.

    The original and all occluded copies are scored in ONE ``model.predict``
    call rather than one call per tile. Keras' per-call overhead otherwise
    dominates, at roughly 100 calls for a 150x150 input with 15px tiles. The
    batch holds ``n_tiles + 1`` copies of the image in memory.

    Args:
        model: object with a Keras-style ``predict(batch, verbose=...)``
            returning an array of shape ``(batch, num_classes)``.
        img_array: a batch containing one image, e.g. ``(1, H, W, C)``; only
            ``img_array[0]`` is used (works for any channel count, including 1).
        class_idx: index of the class score to track.
        patch_size: side length of each occluding square in pixels.

    Returns:
        float64 array of shape ``(H, W)``.
    """
    image = np.asarray(img_array)[0]
    height, width = image.shape[:2]
    origins = [
        (i, j)
        for i in range(0, height, patch_size)
        for j in range(0, width, patch_size)
    ]

    # Row 0 is the untouched image; row k (k >= 1) has tile origins[k-1] zeroed.
    batch = np.repeat(image[np.newaxis], len(origins) + 1, axis=0)
    for k, (i, j) in enumerate(origins, start=1):
        batch[k, i:i + patch_size, j:j + patch_size] = 0  # all channels

    scores = np.asarray(model.predict(batch, verbose=0))[:, class_idx]
    drops = scores[0] - scores[1:]

    sensitivity_map = np.zeros((height, width))
    for (i, j), drop in zip(origins, drops):
        sensitivity_map[i:i + patch_size, j:j + patch_size] = drop
    return sensitivity_map
def visualize_occlusion_map(img, occlusion_map, output_path="/tmp/occlusion_map.png"):
    """Render *occlusion_map* as a semi-transparent heat-map over *img* and save it.

    The heat-map is stretched over the full pixel extent of *img*. If the two
    arrays have different resolutions, drawing each with its own default
    extent would leave the map covering only the top-left corner of the image
    and crop the axes to that corner.

    Args:
        img: array accepted by ``imshow``; 2-D arrays are drawn in grayscale,
            ``(H, W, 3|4)`` arrays as color.
        occlusion_map: 2-D array of sensitivity values.
        output_path: PNG file to write. NOTE(review): the default is a fixed
            path, so concurrent sessions overwrite each other's output.

    Returns:
        ``output_path``.
    """
    img = np.asarray(img)
    height, width = img.shape[:2]
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(img, cmap="gray" if img.ndim == 2 else None)
        overlay = ax.imshow(
            occlusion_map,
            cmap="hot",
            alpha=0.5,
            # Same pixel-centred extent imshow uses by default for `img`.
            extent=(-0.5, width - 0.5, height - 0.5, -0.5),
        )
        ax.axis("off")
        fig.colorbar(overlay, ax=ax)
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return output_path
def create_pdf_report(prediction_text, doctor_notes, image_path, output_path="/tmp/medical_report.pdf"):
    """Write a PDF containing the prediction, the doctor's notes and an image.

    FPDF's built-in "Arial" font only covers Latin-1. Any other character
    (smart quotes, em dashes, non-Latin scripts typed or pasted into the
    notes box) makes FPDF raise while rendering, so such characters are
    replaced with "?" first.

    The image is placed in FPDF's flowing mode (no explicit ``y``). FPDF then
    starts a new page when the scaled image would not fit below the notes,
    instead of drawing it off the bottom of the page.

    Args:
        prediction_text: multi-line prediction summary.
        doctor_notes: free text entered by the user.
        image_path: image file to embed, scaled to the page width minus
            10 units of margin on each side.
        output_path: PDF file to write. NOTE(review): the default is a fixed
            path, so concurrent sessions overwrite each other's report.

    Returns:
        ``output_path``.
    """
    def _latin1(text):
        return str(text).encode("latin-1", "replace").decode("latin-1")

    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    pdf.multi_cell(0, 10, _latin1(prediction_text))
    pdf.ln(10)
    pdf.set_font("Arial", 'B', size=12)
    pdf.cell(0, 10, "Doctor's Notes:")
    pdf.ln(10)
    pdf.set_font("Arial", size=12)
    pdf.multi_cell(0, 10, _latin1(doctor_notes))
    pdf.image(image_path, x=10, w=pdf.w - 20)
    pdf.output(output_path)
    return output_path
def predict_and_generate_report(image, doctor_notes):
    """Classify *image*, write the PDF report and render the occlusion map.

    Args:
        image: the uploaded PIL image.
        doctor_notes: free text to include in the PDF.

    Returns:
        ``(image, prediction_text, pdf_report_path, occlusion_map_path)`` on
        success, or ``(None, error_message, None, None)`` on failure; the
        underlying error is shown with ``st.error``.
    """
    error_result = (None, "Error during prediction. Please check the input image.", None, None)
    if model is None:
        # Otherwise this surfaces as "'NoneType' object has no attribute 'predict'".
        st.error("Error during prediction: the model is not loaded.")
        return error_result
    try:
        img_array = preprocess_image(image)
        predictions = model.predict(img_array)[0]
        class_idx = int(np.argmax(predictions))
        class_label = label_map[class_idx]

        # One "<label>: <percent>%" line per class.
        prediction_results = [f"{label_map[i]}: {p * 100:.2f}%" for i, p in enumerate(predictions)]
        prediction_text = f"Prediction: {class_label}\n" + "\n".join(prediction_results)

        # Save the uploaded image so it can be embedded in the PDF.
        image_path = "/tmp/image.png"
        image.save(image_path)
        pdf_report_path = create_pdf_report(prediction_text, doctor_notes, image_path)

        occlusion_map = compute_occlusion_sensitivity(model, img_array, class_idx)
        # Overlay the map on the image the model actually saw (the resized,
        # normalized input), so both share one pixel grid. The raw upload is
        # usually a different resolution and would not line up with the map.
        occlusion_map_path = visualize_occlusion_map(img_array[0], occlusion_map)
        return image, prediction_text, pdf_report_path, occlusion_map_path
    except Exception as e:
        st.error(f"Error during prediction: {e}")
        return error_result
# Streamlit interface: upload an MRI image, optionally add notes, press
# "Predict" to show the result, a PDF download and the occlusion map.
st.title("Brain Tumor Detection and Occlusion Sensitivity")
st.write("Upload an MRI image to detect brain tumors and view which areas contributed to the prediction.")

uploaded_image = st.file_uploader("Upload MRI Image", type=["png", "jpg", "jpeg"])
doctor_notes = st.text_area("Enter doctor's notes here...", "")

if st.button("Predict"):
    if uploaded_image is None:
        st.warning("Please upload an MRI image to proceed.")
    else:
        image = Image.open(uploaded_image)
        drawn_image, prediction_text, pdf_report_path, occlusion_map_path = predict_and_generate_report(image, doctor_notes)
        # On failure the error has already been shown by the prediction step.
        if drawn_image is not None:
            st.image(drawn_image, caption="Original MRI Image", use_column_width=True)
            st.text(prediction_text)

            # Read the report with a context manager so the handle is closed.
            with open(pdf_report_path, "rb") as report_file:
                report_bytes = report_file.read()
            st.download_button(
                "Download Medical Report",
                data=report_bytes,
                file_name="medical_report.pdf",
                mime="application/pdf",
            )

            # st.image reads the file itself; no PIL handle is left open.
            st.image(occlusion_map_path, caption="Occlusion Sensitivity Map", use_column_width=True)