# Hugging Face Spaces page header captured with this file (Space status: Sleeping).
import copy
import io
import time
import weakref

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from thop import profile
from torchvision import transforms

from model import load_model
# Inference device: CUDA when PyTorch reports it is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Loaded models keyed by version string; populated lazily on first request.
models_cache = dict()

# Preprocessing for every input image: resize to 224x224, convert to a tensor,
# then normalize each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Human-readable labels; position i names the model's i-th output score.
class_names = list(
    (
        "Glioma Tumor",
        "Meningioma Tumor",
        "No Tumor",
        "Pituitary Tumor",
    )
)
# Per-model facts that never change between calls (parameter count, thop op
# count) plus, on GPU hosts, a CPU replica used for CPU timing. Keys are weak so
# an entry disappears once its model is garbage-collected; values never hold
# the model itself, which would keep the key alive forever.
_perf_cache = weakref.WeakKeyDictionary()


def _static_profile(model):
    """Return ``(params_million, flops_billion, cpu_replica_or_None)`` for *model*, cached."""
    cached = _perf_cache.get(model)
    if cached is None:
        dummy = torch.randn(1, 3, 224, 224).to(device)
        with torch.no_grad():
            flops, params = profile(model, inputs=(dummy,), verbose=False)
        # A CUDA model cannot consume a CPU tensor (device-mismatch error), so
        # CPU latency is measured on a CPU-resident copy. On a CPU host the
        # model itself is used and None is stored (see cache comment above).
        cpu_replica = None if device.type == "cpu" else copy.deepcopy(model).cpu()
        cached = (round(params / 1e6, 2), round(flops / 1e9, 2), cpu_replica)
        _perf_cache[model] = cached
    return cached


def calculate_performance(model):
    """Profile *model* on a random 1x3x224x224 input.

    Side effect: puts *model* into eval mode.

    Returns:
        dict with keys
          ``params_million`` -- parameter count reported by thop, in millions
                                (2 decimals);
          ``flops_billion``  -- op count reported by thop ``profile``, in
                                billions (2 decimals). thop documents this
                                value as MACs; the key name is kept for
                                compatibility;
          ``cpu_ms``         -- wall-clock ms of one forward pass on the CPU;
          ``gpu_ms``         -- ms of one CUDA forward pass, or None when the
                                module-level ``device`` is not CUDA.

    Counts and the CPU replica are computed once per model object; both
    latencies are re-measured on every call.
    """
    model.eval()
    params_m, flops_b, cpu_model = _static_profile(model)
    if cpu_model is None:
        cpu_model = model
    cpu_model.eval()

    dummy = torch.randn(1, 3, 224, 224)
    # Timing only: building an autograd graph would inflate time and memory.
    with torch.no_grad():
        start = time.perf_counter()  # monotonic, high resolution
        cpu_model(dummy)
        cpu_ms = round((time.perf_counter() - start) * 1000, 2)

        gpu_ms = None
        if device.type == "cuda":
            gpu_dummy = dummy.to(device)
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            model(gpu_dummy)
            end_event.record()
            # Kernels run asynchronously; both events must have completed
            # before elapsed_time can be read.
            torch.cuda.synchronize()
            gpu_ms = round(start_event.elapsed_time(end_event), 2)

    return {
        'params_million': params_m,
        'flops_billion': flops_b,
        'cpu_ms': cpu_ms,
        'gpu_ms': gpu_ms,
    }
def _render_prediction(img, title):
    """Draw *img* with *title* (axes hidden), save it as PNG, and return it re-opened as a PIL image."""
    # Object-oriented Figure API instead of pyplot: pyplot keeps global
    # "current figure" state shared by concurrent Gradio worker threads, and
    # an exception before plt.close() would leak the figure. A Figure is
    # not registered globally, so it is simply garbage-collected.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(3, 3))
    ax = fig.add_subplot()
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)


def predict_and_monitor(version, image):
    """Classify *image* with model *version* and report profiling metrics.

    Args:
        version: key passed to ``load_model`` (the UI offers 'f' and 'c').
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        ``(pred_dict, metrics, plot_image)``:
          - per-class softmax probabilities, rounded to 4 decimals;
          - the dict from ``calculate_performance``;
          - a PIL image of the input, titled with the top-1 class.

    Raises:
        gr.Error: "Görsel yüklenmedi." ("No image uploaded.") when *image* is
            None; any other failure is wrapped as "Prediction Error: <msg>".
    """
    # Validate before loading a model, and outside the try below so this
    # user-facing message is not re-wrapped as "Prediction Error: ...".
    if image is None:
        raise gr.Error("Görsel yüklenmedi.")
    try:
        if version not in models_cache:
            models_cache[version] = load_model(version, device)
        model = models_cache[version]

        img = image.convert("RGB")
        tensor = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(tensor)
        probs = F.softmax(logits, dim=1)[0].tolist()
        if len(probs) != len(class_names):
            raise ValueError(
                f"model returned {len(probs)} scores, expected {len(class_names)}"
            )
        pred_dict = {name: round(p, 4) for name, p in zip(class_names, probs)}

        metrics = calculate_performance(model)
        top1 = max(pred_dict, key=pred_dict.get)
        plot_image = _render_prediction(img, f"{top1}: {pred_dict[top1]*100:.1f}%")
        return pred_dict, metrics, plot_image
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Prediction Error: {e}") from e
# Gradio UI: pick a model version and an image, then press "Run" to get class
# probabilities, profiling metrics and an annotated preview of the input.
with gr.Blocks() as demo:
    gr.Markdown(value="Tumor Diagnosis with Vbai-TS 1.0(f,c)")

    # Inputs, side by side.
    with gr.Row():
        version = gr.Radio(
            choices=['f', 'c'],
            value='c',
            label="Model Version | f => Fastest, c => Classic",
        )
        image_in = gr.Image(type="pil", label="MRI or fMRI Image")

    # Outputs, side by side.
    with gr.Row():
        preds = gr.JSON(label="Prediction Probabilities")
        stats = gr.JSON(label="Performance Metrics")
        plot = gr.Image(label="Prediction")

    btn = gr.Button(value="Run")
    btn.click(
        predict_and_monitor,
        [version, image_in],
        [preds, stats, plot],
    )

if __name__ == '__main__':
    demo.launch()