import gradio as gr
import torch
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

from utils import *
from supervised import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load Models
models = {
    "unet": UNet(num_classes=2).to(device),
    "segformer": Segformer(num_classes=2).to(device),
    "inception": Inception(num_classes=2).to(device),
    "kmeans": KMeans(n_clusters=2),
    "gmm": GaussianMixture(n_components=2),
}

# Load trained weights for the deep models; KMeans and GMM have no checkpoints to load.
models["unet"].load_state_dict(torch.load("unet.pt", map_location=device))
models["segformer"].load_state_dict(torch.load("segformer.pt", map_location=device))
models["inception"].load_state_dict(torch.load("inception.pt", map_location=device))

# Switch the deep models to evaluation mode.
for model in models.values():
    if isinstance(model, (UNet, Segformer, Inception)):
        model.eval()

# Inference function: run the selected model and return the overlay, binary mask, and a status message.
def inference(image, model_name, postprocess_mode):
    model = models[model_name]
    status_text = f"✅ Inference with {model_name.upper()} and postprocessing mode: {postprocess_mode}"
    bw_mask, overlay = predict_and_visualize_single(model, image, postprocess_mode=postprocess_mode)
    return overlay, bw_mask, status_text

# Gradio Interface
with gr.Blocks(theme=gr.themes.Base(primary_hue="rose", secondary_hue="slate")) as demo:
    gr.Markdown("## 🩺 Skin Lesion Segmentation")
    gr.Markdown("Upload a skin image, choose a model, and view segmentation results.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type='numpy', label="📷 Upload Image")
            model_choice = gr.Radio(
                choices=["unet", "segformer", "inception", "kmeans", "gmm"],
                label="Model",
                value="unet"
            )
            post_choice = gr.Radio(
                choices=["none", "open", "close", "erosion", "dilation"],
                label="Postprocessing",
                value="none"
            )
            run_btn = gr.Button("▶ Run Segmentation")

        with gr.Column(scale=2):
            with gr.Row():
                overlay_output = gr.Image(type='numpy', label="🎯 Overlay")
                mask_output = gr.Image(type='numpy', label="🖤 Predicted Mask")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        gr.Examples(
            examples=["./examples/ISIC_0012880.jpg", "./examples/ISIC_0015972.jpg"],
            inputs=[image_input],
            label="Use Example Images"
        )

    with gr.Accordion("ℹ️ Legend", open=False):
        gr.Markdown("""
        - **🔴 Red**: Predicted lesion overlay
        - **⚫ White**: Binary mask
        - **Postprocessing**: Cleans up noisy segmentation
        """)

    run_btn.click(
        fn=inference,
        inputs=[image_input, model_choice, post_choice],
        outputs=[overlay_output, mask_output, status]
    )

demo.launch(share=True)