# Vbai-TS-1.0 / app.py
# Last commit c176ccc by eyupipler: "Update app.py" (verified)
import gradio as gr
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from model import load_model
import matplotlib.pyplot as plt
import numpy as np
from thop import profile
import io
# Inference device: CUDA when a GPU is visible to torch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Lazily filled cache mapping a model version key to its loaded model,
# so each version is loaded at most once per process.
models_cache = dict()
# Preprocessing applied to every uploaded image before inference:
# resize to a fixed 224x224, convert to a float tensor, then normalize each
# RGB channel. The mean/std values are the common ImageNet statistics;
# presumably the model was trained with the same normalization — confirm.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Human-readable labels in model-output order: class_names[i] is the label
# for column i of the softmax over the model's logits.
class_names = [
    "Glioma Tumor",
    "Meningioma Tumor",
    "No Tumor",
    "Pituitary Tumor",
]
def calculate_performance(model):
    """Profile *model* and time one forward pass on CPU and, if available, GPU.

    A single random input of shape (1, 3, 224, 224) is used for every
    measurement. This matches the size produced by the module-level
    ``transform`` for an RGB image.

    Args:
        model: A ``torch.nn.Module`` placed on the module-level ``device``.
            It is switched to eval mode as a side effect.

    Returns:
        dict with keys:
            ``params_million``: parameter count reported by ``thop.profile``
                divided by 1e6, rounded to 2 decimals.
            ``flops_billion``: operation count reported by ``thop.profile``
                divided by 1e9, rounded to 2 decimals.
                NOTE(review): thop commonly counts MACs rather than FLOPs —
                confirm before relying on the label.
            ``cpu_ms``: wall-clock time of one CPU forward pass, in ms.
            ``gpu_ms``: CUDA-event time of one GPU forward pass, in ms, or
                ``None`` when ``device`` is not CUDA.
    """
    import copy
    import time

    model.eval()
    dummy = torch.randn(1, 3, 224, 224).to(device)

    # Inference-only measurements: with autograd disabled no graph is built,
    # which keeps memory down and keeps the timings representative.
    with torch.no_grad():
        flops, params = profile(model, inputs=(dummy,), verbose=False)

        # The model may live on the GPU. Feeding it a CPU tensor would raise a
        # device-mismatch error, so time a CPU copy instead. nn.Module.cpu()
        # moves parameters in place, so copy first to leave the cached model
        # on its device.
        cpu_model = model if device.type == "cpu" else copy.deepcopy(model).cpu()
        cpu_input = dummy.cpu()
        start = time.perf_counter()
        cpu_model(cpu_input)
        cpu_ms = round((time.perf_counter() - start) * 1000, 2)

        if device.type == "cuda":
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            model(dummy)
            end_event.record()
            # CUDA work is asynchronous; elapsed_time is only valid once both
            # recorded events have completed.
            torch.cuda.synchronize()
            gpu_ms = round(start_event.elapsed_time(end_event), 2)
        else:
            gpu_ms = None

    return {
        "params_million": round(params / 1e6, 2),
        "flops_billion": round(flops / 1e9, 2),
        "cpu_ms": cpu_ms,
        "gpu_ms": gpu_ms,
    }
def _render_prediction_plot(img, title):
    """Render *img* with *title* into a new PIL image via a PNG round-trip.

    Uses the object-oriented ``matplotlib.figure.Figure`` API instead of
    ``pyplot``. This avoids sharing pyplot's global "current figure" state
    between concurrent Gradio requests.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(3, 3))
    ax = fig.add_subplot(111)
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    rendered = Image.open(buf)
    # Decode eagerly so the returned image does not lazily read the buffer later.
    rendered.load()
    return rendered


def predict_and_monitor(version, image):
    """Classify an uploaded image and report model performance metrics.

    Args:
        version: Model version key passed to ``load_model`` (the UI offers
            ``'f'`` and ``'c'``). Loaded models are cached in ``models_cache``.
        image: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        tuple ``(pred_dict, metrics, plot_image)``:
            ``pred_dict`` maps each name in ``class_names`` to its softmax
            probability, rounded to 4 decimals.
            ``metrics`` is the dict returned by ``calculate_performance``.
            ``plot_image`` is a PIL image of the input, titled with the top-1
            class and its probability.

    Raises:
        gr.Error: "Görsel yüklenmedi." when no image was uploaded, or
            "Prediction Error: ..." when loading, inference, profiling or
            rendering fails.
    """
    # Validate before loading any model, and outside the broad try/except, so
    # this message is shown as-is instead of being re-wrapped below.
    if image is None:
        raise gr.Error("Görsel yüklenmedi.")
    try:
        if version not in models_cache:
            models_cache[version] = load_model(version, device)
        model = models_cache[version]
        # load_model is not visible here; eval() is idempotent and guarantees
        # inference-mode dropout/batch-norm even on the very first prediction.
        model.eval()

        img = image.convert("RGB")
        tensor = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(tensor)
        # One host transfer instead of a device sync per element.
        probs = F.softmax(logits, dim=1)[0].tolist()
        pred_dict = {name: round(probs[i], 4) for i, name in enumerate(class_names)}

        metrics = calculate_performance(model)

        # Choose the winner from unrounded probabilities so rounding ties
        # cannot change the reported class.
        top_idx = max(range(len(class_names)), key=probs.__getitem__)
        top1 = class_names[top_idx]
        plot_image = _render_prediction_plot(
            img, f"{top1}: {pred_dict[top1]*100:.1f}%"
        )
        return pred_dict, metrics, plot_image
    except Exception as e:
        raise gr.Error(f"Prediction Error: {e}") from e
# Gradio UI: choose a model version and upload an image, then display the
# class probabilities, performance metrics and an annotated preview.
with gr.Blocks() as demo:
    gr.Markdown("Tumor Diagnosis with Vbai-TS 1.0(f,c)")

    # Input controls.
    with gr.Row():
        version = gr.Radio(
            ["f", "c"],
            value="c",
            label="Model Version | f => Fastest, c => Classic",
        )
        image_in = gr.Image(type="pil", label="MRI or fMRI Image")

    # Output panels, filled by predict_and_monitor's 3-tuple in this order.
    with gr.Row():
        preds = gr.JSON(label="Prediction Probabilities")
        stats = gr.JSON(label="Performance Metrics")
        plot = gr.Image(label="Prediction")

    btn = gr.Button("Run")
    btn.click(
        predict_and_monitor,
        inputs=[version, image_in],
        outputs=[preds, stats, plot],
    )

if __name__ == "__main__":
    demo.launch()