File size: 3,031 Bytes
5b303e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
from PIL import Image
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from utils import *
from supervised import *

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load Models
# Registry of every segmentation backend selectable in the UI, keyed by the
# name used in the "Model" radio. The three neural networks are moved to
# `device`; KMeans and GaussianMixture are scikit-learn estimators that are
# only constructed here (not fitted).
models = {
    "unet": UNet(num_classes=2).to(device),
    "segformer": Segformer(num_classes=2).to(device),
    "inception": Inception(num_classes=2).to(device),
    "kmeans": KMeans(n_clusters=2),
    "gmm": GaussianMixture(n_components=2),
}

# Checkpoint file (relative to the working directory) for each network.
_CHECKPOINT_PATHS = {
    "unet": "unet.pt",
    "segformer": "segformer.pt",
    "inception": "inception.pt",
}

# Restore trained weights; map_location remaps saved tensors onto `device`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
for _net_name, _ckpt_path in _CHECKPOINT_PATHS.items():
    _state = torch.load(_ckpt_path, map_location=device)
    models[_net_name].load_state_dict(_state)

# Only the neural networks get switched to evaluation mode; the clustering
# estimators have no such mode.
for _net_name in _CHECKPOINT_PATHS:
    models[_net_name].eval()

# Inference function
def inference(image, model_name, postprocess_mode):
    """Run the selected segmentation model on one image for the Gradio UI.

    Args:
        image: Value of the ``gr.Image(type='numpy')`` input, or ``None`` when
            the user pressed Run without uploading an image.
        model_name: Key into the module-level ``models`` registry
            (e.g. ``"unet"`` or ``"kmeans"``).
        postprocess_mode: Postprocessing option forwarded unchanged to
            ``predict_and_visualize_single`` (e.g. ``"none"``, ``"open"``).

    Returns:
        ``(overlay, bw_mask, status_text)``, matching the order of the
        ``[overlay_output, mask_output, status]`` outputs wired to the Run
        button. When no image is given, both images are ``None`` and the
        status text asks the user to upload one.
    """
    # gr.Image yields None when nothing was uploaded; report it in the status
    # box instead of passing None into the prediction helper.
    if image is None:
        return None, None, "⚠️ Please upload an image before running segmentation."

    model = models[model_name]
    # NOTE(review): whether predict_and_visualize_single disables autograd
    # (torch.no_grad) and how it fits the unfitted KMeans/GMM estimators is
    # defined outside this file — confirm there.
    bw_mask, overlay = predict_and_visualize_single(model, image, postprocess_mode=postprocess_mode)
    status_text = f"✅ Inference with {model_name.upper()} and postprocessing mode: {postprocess_mode}"
    return overlay, bw_mask, status_text

# Gradio Interface
# Left column: inputs (image, model, postprocessing) and the Run button.
# Right column: overlay image, binary mask, and a read-only status line.
with gr.Blocks(theme=gr.themes.Base(primary_hue="rose", secondary_hue="slate")) as demo:
    gr.Markdown("## 🩺 Skin Lesion Segmentation")
    gr.Markdown("Upload a skin image, choose a model, and view segmentation results.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type='numpy', label="📷 Upload Image")
            model_choice = gr.Radio(
                # Derived from the `models` registry (insertion order preserved)
                # so the selectable options always match the keys `inference`
                # looks up.
                choices=list(models),
                label="Model",
                value="unet"
            )
            post_choice = gr.Radio(
                choices=["none", "open", "close", "erosion", "dilation"],
                label="Postprocessing",
                value="none"
            )
            run_btn = gr.Button("▶ Run Segmentation")

        with gr.Column(scale=2):
            with gr.Row():
                overlay_output = gr.Image(type='numpy', label="🎯 Overlay")
                mask_output = gr.Image(type='numpy', label="🖤 Predicted Mask")
            status = gr.Textbox(label="Status", interactive=False)

    with gr.Row():
        # Paths are resolved relative to the process's working directory.
        gr.Examples(
            examples=["./examples/ISIC_0012880.jpg", "./examples/ISIC_0015972.jpg"],
            inputs=[image_input],
            label="Use Example Images"
        )

    with gr.Accordion("ℹ️ Legend", open=False):
        gr.Markdown("""

        - **🔴 Red**: Predicted lesion overlay  

        - **⚪ White**: Binary mask  

        - **Postprocessing**: Cleans up noisy segmentation

        """)

    # Output order must match the (overlay, bw_mask, status_text) tuple
    # returned by `inference`.
    run_btn.click(
        fn=inference,
        inputs=[image_input, model_choice, post_choice],
        outputs=[overlay_output, mask_output, status]
    )

# share=True additionally requests a temporary public share link.
demo.launch(share=True)