|
import gradio as gr
|
|
import torch
|
|
import torchvision.models as models
|
|
from huggingface_hub import hf_hub_download
|
|
import json
|
|
from PIL import Image
|
|
import torchvision.transforms as transforms
|
|
|
|
|
|
# Run inference on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The checkpoint and the label list are both fetched from the same Hub repository.
REPO_ID = "AventIQ-AI/resnet18-cataract-detection-system"
weights_path = hf_hub_download(repo_id=REPO_ID, filename="cataract_detection_resnet18_quantized.pth")
labels_path = hf_hub_download(repo_id=REPO_ID, filename="class_names.json")

# Decode explicitly as UTF-8 so the platform's locale default cannot
# mis-decode non-ASCII class names.
with open(labels_path, "r", encoding="utf-8") as f:
    # NOTE(review): later code indexes class_labels by integer position, so
    # the JSON document is presumably a list of names — confirm.
    class_labels = json.load(f)
|
|
|
|
|
|
# Rebuild the ResNet-18 architecture without pretrained weights. Calling
# resnet18() with no arguments does this on both old and new torchvision
# releases and avoids the deprecated ``pretrained=False`` keyword.
model = models.resnet18()
num_classes = len(class_labels)
# Replace the classification head with one output per class label; take the
# input width from the existing head instead of hard-coding it.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
# NOTE(review): torch.load unpickles the downloaded checkpoint, which can run
# arbitrary code if the file is not trusted. Consider weights_only=True if the
# installed torch version supports it and the file holds only tensors.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
# Switch to inference mode (affects layers such as batch norm and dropout).
model.eval()
|
|
|
|
|
|
# Per-channel mean and standard deviation used by the Normalize step.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input image, in order:
# resize to 224x224, convert to a tensor, then normalize each channel.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
|
def predict(image):
    """Classify an uploaded eye image and return per-class confidences.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        A dict mapping every class label to its softmax probability (the
        ``{label: confidence}`` form consumed by ``gr.Label``), or ``None``
        when no image was provided.
    """
    if image is None:
        # Nothing was uploaded; return an empty result instead of crashing
        # inside the preprocessing pipeline.
        return None

    # The normalization step is configured for three channels, so coerce
    # grayscale / RGBA inputs to RGB before preprocessing.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        # Report real probabilities rather than a fabricated 1.0 confidence.
        probabilities = torch.softmax(logits, dim=1)[0].tolist()

    return {class_labels[index]: score for index, score in enumerate(probabilities)}
|
|
|
|
|
|
# Input and output widgets for the web UI.
image_input = gr.Image(type="pil", label="Upload an Eye Image")
result_label = gr.Label(label="Cataract Detection Result")

# NOTE(review): the non-ASCII characters in the title and description look
# like UTF-8 emoji that were decoded with the wrong codec — confirm against
# the original file before changing them.
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=result_label,
    title="ποΈ Cataract Detection System π₯",
    description="π¬ Upload an eye image, and the AI model will determine if cataract is present! π",
    theme="huggingface",
    allow_flagging="never",
)

# Start the Gradio server.
demo.launch()