File size: 3,090 Bytes
539c115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c176ccc
539c115
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from model import load_model
import matplotlib.pyplot as plt
import numpy as np
from thop import profile
import io

# Inference device: the default CUDA device when PyTorch can see one,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Filled on demand: maps a model-version key to the model loaded for it, so
# repeated requests for the same version reuse one instance.
models_cache = {}

# Preprocessing applied to each uploaded image before inference: resize to
# 224x224, convert to a float tensor, then normalise each RGB channel. The
# mean/std values are the standard ImageNet statistics; presumably they match
# the model's training preprocessing (TODO confirm).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Human-readable labels; position i names the model's i-th output logit, so
# this order must match the order the classifier was trained with.
class_names = [
    "Glioma Tumor",
    "Meningioma Tumor",
    "No Tumor",
    "Pituitary Tumor",
]

def calculate_performance(model):
    """Profile *model* and time one forward pass on CPU and, if present, GPU.

    The model is profiled with ``thop.profile`` on a random 1x3x224x224 input
    (the size ``transform`` produces) and then timed.

    Args:
        model: the classifier to measure. It is expected to live on the
            module-level ``device``; models from ``load_model(version, device)``
            presumably do (TODO confirm).

    Returns:
        dict with these keys:
            ``params_million``: thop parameter count / 1e6, rounded to 2 dp.
            ``flops_billion``: thop op count / 1e9, rounded to 2 dp.
            ``cpu_ms``: wall-clock milliseconds of one CPU forward pass.
            ``gpu_ms``: milliseconds of one CUDA forward pass, measured with
                CUDA events, or ``None`` when CUDA is unavailable.
    """
    import time

    model.eval()
    dummy = torch.randn(1, 3, 224, 224).to(device)
    flops, params = profile(model, inputs=(dummy,), verbose=False)
    params_m = round(params / 1e6, 2)
    flops_b = round(flops / 1e9, 2)

    on_cuda = device.type == 'cuda'
    # Timing runs need no autograd graph; building one only wastes memory.
    with torch.no_grad():
        # A CPU measurement needs the weights on the CPU as well. Feeding a CPU
        # tensor to a CUDA-resident model raises a device-mismatch error. The
        # model is moved temporarily and always restored, because the caller
        # keeps the same object cached for later requests.
        # NOTE(review): this briefly relocates a shared model. It is only safe
        # while the handler is not run concurrently for the same model.
        cpu_input = dummy.cpu()
        if on_cuda:
            model.cpu()
        try:
            start = time.perf_counter()
            model(cpu_input)
            cpu_ms = round((time.perf_counter() - start) * 1000, 2)
        finally:
            if on_cuda:
                model.to(device)

        if on_cuda:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            # Drain queued work (e.g. the host->device weight copy) so it is
            # not attributed to the timed forward pass.
            torch.cuda.synchronize()
            start_event.record()
            model(dummy)
            end_event.record()
            torch.cuda.synchronize()
            gpu_ms = round(start_event.elapsed_time(end_event), 2)
        else:
            gpu_ms = None

    return {'params_million': params_m, 'flops_billion': flops_b, 'cpu_ms': cpu_ms, 'gpu_ms': gpu_ms}

def _render_prediction_plot(img, title):
    """Draw *img* under *title* on a 3x3-inch figure and return it as a PNG-backed PIL image."""
    # Use the object-oriented Figure API instead of pyplot. Pyplot keeps
    # process-global figure state and may pick a GUI backend, both unsuitable
    # for a request handler that may run off the main thread. A bare Figure
    # needs no explicit close.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(3, 3))
    ax = fig.subplots()
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)


def predict_and_monitor(version, image):
    """Classify an uploaded image and report the model's performance metrics.

    Args:
        version: model-version key forwarded to ``load_model`` (the UI offers
            'f' and 'c').
        image: the uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(pred_dict, metrics, plot_image)``:
            ``pred_dict``: maps each class name to its softmax probability,
                rounded to 4 decimals.
            ``metrics``: the dict returned by ``calculate_performance``.
            ``plot_image``: a PIL image of the input, titled with the top-1
                class and its percentage.

    Raises:
        gr.Error: if no image was supplied. Also raised, wrapping and chaining
            the original exception, if loading, inference, profiling or
            plotting fails.
    """
    # Validate first, so an empty submission neither triggers a model load nor
    # gets re-wrapped as "Prediction Error: ..." by the handler below.
    if image is None:
        raise gr.Error("Görsel yüklenmedi.")

    try:
        if version not in models_cache:
            models_cache[version] = load_model(version, device)
        model = models_cache[version]

        img = image.convert("RGB")
        # unsqueeze(0) adds the batch dimension expected by the model.
        tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = model(tensor)
            probs = F.softmax(logits, dim=1)[0].tolist()

        pred_dict = {name: round(probs[i], 4) for i, name in enumerate(class_names)}
        metrics = calculate_performance(model)

        top1 = max(pred_dict, key=pred_dict.get)
        plot_image = _render_prediction_plot(img, f"{top1}: {pred_dict[top1]*100:.1f}%")
        return pred_dict, metrics, plot_image
    except gr.Error:
        # Already a user-facing message; don't wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Prediction Error: {e}") from e

# Gradio UI. Components render in creation order:
#   - a title;
#   - a row with the model-version selector and the image upload;
#   - a row with the prediction JSON, the performance-metrics JSON and the
#     annotated image;
#   - a button that runs predict_and_monitor on the current inputs.
with gr.Blocks() as demo:
    gr.Markdown("Tumor Diagnosis with Vbai-TS 1.0(f,c)")
    with gr.Row():
        # Version key passed to predict_and_monitor; 'c' is preselected.
        version = gr.Radio(['f','c'], value='c', label="Model Version | f => Fastest, c => Classic")
        # type="pil" asks Gradio to deliver the upload as a PIL image.
        image_in = gr.Image(type="pil", label="MRI or fMRI Image")
    with gr.Row():
        preds = gr.JSON(label="Prediction Probabilities")
        stats = gr.JSON(label="Performance Metrics")
        plot = gr.Image(label="Prediction")
    btn = gr.Button("Run")
    # outputs map positionally onto predict_and_monitor's returned tuple:
    # (probabilities dict, metrics dict, rendered image).
    btn.click(fn=predict_and_monitor, inputs=[version, image_in], outputs=[preds, stats, plot])

# Start the Gradio server only when executed as a script, not when imported.
if __name__ == '__main__':
    demo.launch()