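"""Gradio app for polyp segmentation in colonoscopy images.

Downloads a U-Net checkpoint from the Hugging Face Hub
(ibrahim313/unet-adam-diceloss) and serves an interactive demo.
"""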
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
import io
# U-Net model definition
class UNET(nn.Module):
    def __init__(self, dropout_rate=0.1, ch=32):
        super(UNET, self).__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        def conv_block(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout_rate),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout_rate)
            )
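        # Channels double at each encoder stage (ch -> 16*ch at the
        # bottleneck) and the decoder mirrors the progression back down.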
        self.encoder1 = conv_block(3, ch)
        self.encoder2 = conv_block(ch, ch*2)
        self.encoder3 = conv_block(ch*2, ch*4)
        self.encoder4 = conv_block(ch*4, ch*8)
        self.bottle_neck = conv_block(ch*8, ch*16)
        self.upsample1 = nn.ConvTranspose2d(ch*16, ch*8, kernel_size=2, stride=2)
        self.decoder1 = conv_block(ch*16, ch*8)
        self.upsample2 = nn.ConvTranspose2d(ch*8, ch*4, kernel_size=2, stride=2)
        self.decoder2 = conv_block(ch*8, ch*4)
        self.upsample3 = nn.ConvTranspose2d(ch*4, ch*2, kernel_size=2, stride=2)
        self.decoder3 = conv_block(ch*4, ch*2)
        self.upsample4 = nn.ConvTranspose2d(ch*2, ch, kernel_size=2, stride=2)
        self.decoder4 = conv_block(ch*2, ch)
        self.final = nn.Conv2d(ch, 1, kernel_size=1)
    def forward(self, x):
        # Encoder
        c1 = self.encoder1(x)
        c2 = self.encoder2(self.pool(c1))
        c3 = self.encoder3(self.pool(c2))
        c4 = self.encoder4(self.pool(c3))
        c5 = self.bottle_neck(self.pool(c4))
        # Decoder: upsample, concatenate the matching encoder features
        # (skip connection), then convolve
        u6 = self.upsample1(c5)
        u6 = torch.cat([c4, u6], dim=1)
        c6 = self.decoder1(u6)
        u7 = self.upsample2(c6)
        u7 = torch.cat([c3, u7], dim=1)
        c7 = self.decoder2(u7)
        u8 = self.upsample3(c7)
        u8 = torch.cat([c2, u8], dim=1)
        c8 = self.decoder3(u8)
        u9 = self.upsample4(c8)
        u9 = torch.cat([c1, u9], dim=1)
        c9 = self.decoder4(u9)
        return self.final(c9)
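# Shape sanity check (illustrative only, not executed at import time):
#     with torch.no_grad():
#         out = UNET(ch=32)(torch.randn(1, 3, 384, 384))
#         assert out.shape == (1, 1, 384, 384)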
# Global variables
model = None
device = torch.device('cpu')  # free-tier HF Spaces run on CPU
transform = A.Compose([
    A.Resize(384, 384),
    A.Normalize(mean=(0, 0, 0), std=(1, 1, 1), max_pixel_value=255),
    ToTensorV2()
])
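# Note: this Normalize call only rescales pixel values to [0, 1]; it assumes
# the model was trained on [0, 1]-scaled inputs rather than ImageNet stats.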
def load_model():
    """Download and load the trained U-Net from the Hugging Face Hub."""
    global model
    try:
        print("📥 Downloading model from Hugging Face...")
        model_path = hf_hub_download(
            repo_id="ibrahim313/unet-adam-diceloss",
            filename="pytorch_model.bin"
        )
        model = UNET(ch=32)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        print("✅ Model loaded successfully!")
        return "✅ Model loaded from ibrahim313/unet-adam-diceloss"
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return f"❌ Error: {e}"
def predict_polyp(image, threshold=0.5):
    """Segment polyps in the uploaded image and visualize the result."""
    if model is None:
        return None, "❌ Model not loaded! Please wait for the model to load.", None
    if image is None:
        return None, "❌ Please upload an image first!", None
    try:
        # Convert the input to an RGB numpy array
        if isinstance(image, Image.Image):
            original_image = np.array(image.convert('RGB'))
        else:
            original_image = np.array(image)
        # Preprocess: resize to 384x384, scale to [0, 1], convert to tensor
        transformed = transform(image=original_image)
        input_tensor = transformed['image'].unsqueeze(0).float().to(device)
        # Forward pass, then threshold the sigmoid probabilities
        with torch.no_grad():
            prediction = model(input_tensor)
            prediction = torch.sigmoid(prediction)
            prediction = (prediction > threshold).float()
        pred_mask = prediction.squeeze().cpu().numpy()
        # Coverage metrics
        polyp_pixels = np.sum(pred_mask)
        total_pixels = pred_mask.shape[0] * pred_mask.shape[1]
        polyp_percentage = (polyp_pixels / total_pixels) * 100
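        # Coverage is measured on the 384x384 prediction rather than the
        # original resolution; the percentage is roughly scale-independent.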
        # Build the three-panel visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        # Original image
        axes[0].imshow(original_image)
        axes[0].set_title('🖼️ Original Image', fontsize=14)
        axes[0].axis('off')
        # Predicted mask (at model resolution)
        axes[1].imshow(pred_mask, cmap='gray')
        axes[1].set_title('🎭 Predicted Mask', fontsize=14)
        axes[1].axis('off')
        # Overlay: resize the 384x384 mask back to the original resolution
        # so it lines up with the uploaded image
        mask_overlay = np.array(
            Image.fromarray((pred_mask * 255).astype(np.uint8)).resize(
                (original_image.shape[1], original_image.shape[0]),
                resample=Image.NEAREST
            )
        ) / 255.0
        axes[2].imshow(original_image)
        axes[2].imshow(mask_overlay, cmap='Reds', alpha=0.6)
        axes[2].set_title('🔍 Detection Overlay', fontsize=14)
        axes[2].axis('off')
        # Main title with the headline result
        if polyp_pixels > 100:
            main_title = f"🚨 POLYP DETECTED! Coverage: {polyp_percentage:.2f}%"
            title_color = 'red'
        else:
            main_title = f"✅ No Polyp Detected - Coverage: {polyp_percentage:.2f}%"
            title_color = 'green'
        fig.suptitle(main_title, fontsize=16, fontweight='bold', color=title_color)
        plt.tight_layout()
        # Render the figure to a PIL image
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        result_image = Image.open(buf).copy()
        buf.close()
        plt.close(fig)
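        # The fixed 100-pixel cutoff (~0.07% of the 384x384 mask) is a simple
        # heuristic for the headline verdict, not a clinically validated
        # decision threshold.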
        # Build the detailed results text
        if polyp_pixels > 100:
            status_emoji = "🚨"
            status_text = "POLYP DETECTED"
            recommendation = "⚠️ **Recommendation:** Medical review recommended"
        else:
            status_emoji = "✅"
            status_text = "NO POLYP DETECTED"
            recommendation = "✅ **Recommendation:** Continue routine monitoring"
        results_text = f"""
## {status_emoji} **{status_text}**

### 📊 **Analysis Results:**
- **Polyp Coverage:** {polyp_percentage:.3f}%
- **Detected Pixels:** {int(polyp_pixels):,} / {total_pixels:,}
- **Detection Threshold:** {threshold}

### 🏥 **Clinical Assessment:**
{recommendation}

### 🔬 **Technical Details:**
- **Model:** U-Net (32 base channels)
- **Input Size:** 384×384 pixels
- **Architecture:** Encoder-decoder with skip connections
"""
        return result_image, results_text, pred_mask
    except Exception as e:
        error_msg = f"❌ **Error processing image:** {str(e)}"
        return None, error_msg, None
def load_example_image(image_num):
    """Load one of the sample colonoscopy images bundled with this Space."""
    filenames = {
        1: "cju0qoxqj9q6s0835b43399p4.jpg",
        2: "cju0roawvklrq0799vmjorwfv.jpg",
    }
    try:
        image_path = hf_hub_download(
            repo_id="ibrahim313/unet-adam-diceloss",
            filename=filenames.get(image_num, filenames[2]),
            repo_type="space"
        )
        return Image.open(image_path)
    except Exception as e:
        print(f"Error loading example image {image_num}: {e}")
        return None
# Load model when app starts
print("πŸš€ Starting Polyp Detection App...")
load_status = load_model()
print(load_status)
# Create Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="🏥 Polyp Detection AI") as demo:
    # Header
    gr.HTML("""
    <div style="text-align: center; padding: 30px; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
        <h1 style="margin: 0; font-size: 2.5em;">🏥 AI Polyp Detection System</h1>
        <p style="margin: 10px 0 0 0; font-size: 1.2em;">Advanced Medical Imaging with Deep Learning</p>
        <p style="margin: 5px 0 0 0; opacity: 0.9;">Upload colonoscopy images for intelligent polyp detection</p>
    </div>
    """)
    # Model info
    gr.HTML(f"""
    <div style="background: black; color: white; padding: 15px; border-radius: 8px; border-left: 4px solid #0ea5e9; margin-bottom: 20px;">
        <strong>🔬 Model:</strong> ibrahim313/unet-adam-diceloss<br>
        <strong>📐 Architecture:</strong> U-Net with 32 base channels<br>
        <strong>🎯 Dataset:</strong> Trained on Kvasir-SEG (1,000 polyp images)<br>
        <strong>📸 Examples:</strong> 2 test colonoscopy images included<br>
        <strong>⚡ Status:</strong> {load_status}
    </div>
    """)
    # Main interface
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Upload Image</h3>")
            input_image = gr.Image(
                label="Drop colonoscopy image here",
                type="pil",
                height=300
            )
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.1,
                label="🎯 Detection Threshold",
                info="Lower values = more sensitive detection"
            )
            analyze_btn = gr.Button(
                "🔍 Analyze for Polyps",
                variant="primary",
                size="lg"
            )
            gr.HTML("<br>")
            # Quick examples
            gr.HTML("<h4>📸 Try Sample Images:</h4>")
            gr.HTML("<p style='font-size: 0.9em; color: #666; margin: 5px 0;'>Click to load colonoscopy test images</p>")
            with gr.Row():
                example1_btn = gr.Button("🖼️ Test Image 1", size="sm", variant="secondary")
                example2_btn = gr.Button("🖼️ Test Image 2", size="sm", variant="secondary")
        with gr.Column(scale=2):
            gr.HTML("<h3>📊 Detection Results</h3>")
            output_image = gr.Image(
                label="Analysis Results",
                height=400
            )
            results_text = gr.Markdown(
                value="Upload an image and click 'Analyze for Polyps' to see results.",
                label="Detailed Analysis"
            )
    # Event handlers
    analyze_btn.click(
        fn=predict_polyp,
        inputs=[input_image, threshold_slider],
        outputs=[output_image, results_text, gr.State()]
    )
    # Example button handlers
    example1_btn.click(
        fn=lambda: load_example_image(1),
        inputs=[],
        outputs=[input_image]
    )
    example2_btn.click(
        fn=lambda: load_example_image(2),
        inputs=[],
        outputs=[input_image]
    )
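    # Note: predict_polyp's third return value (the raw mask array) is routed
    # into an anonymous gr.State and is not displayed in the UI.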
    # Footer
    gr.HTML("""
    <div style="text-align: center; padding: 20px; margin-top: 40px; border-top: 2px solid #e5e7eb; background: #f9fafb;">
        <p style="margin: 0; color: #dc2626; font-weight: bold;">
            ⚠️ MEDICAL DISCLAIMER
        </p>
        <p style="margin: 5px 0; color: #4b5563;">
            This AI system is for research and educational purposes only.<br>
            Always consult qualified medical professionals for clinical decisions.
        </p>
        <p style="margin: 10px 0 0 0; color: #6b7280; font-size: 0.9em;">
            🔬 Powered by PyTorch | 🤗 Hosted on Hugging Face | 📊 Gradio Interface
        </p>
    </div>
    """)
# Launch the app
if __name__ == "__main__":
    demo.launch()