"""Gradio demo: classify an uploaded eye image with a ResNet-18 cataract model."""

import gradio as gr
import torch
import torchvision.models as models
from huggingface_hub import hf_hub_download
import json
from PIL import Image
import torchvision.transforms as transforms

# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub repository that hosts both the weights and the label list.
REPO_ID = "AventIQ-AI/resnet18-cataract-detection-system"

# Download (or reuse the local cache of) the model weights and class labels.
weights_path = hf_hub_download(repo_id=REPO_ID, filename="cataract_detection_resnet18_quantized.pth")
labels_path = hf_hub_download(repo_id=REPO_ID, filename="class_names.json")
with open(labels_path, "r", encoding="utf-8") as f:
    # Indexed by the model's output position in predict(); assumed to be a
    # JSON list of class names -- TODO confirm against the repo's file.
    class_labels = json.load(f)

# Build an untrained ResNet-18 and replace its head so the output size
# matches the number of classes in class_names.json.
model = models.resnet18(pretrained=False)
num_classes = len(class_labels)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
# SECURITY NOTE: torch.load unpickles the file, and this file comes from a
# remote Hub repo. Consider torch.load(..., weights_only=True) if the installed
# torch version supports it.
# NOTE(review): the filename says "quantized" but the state dict is loaded into
# a float resnet18; this only works if the checkpoint holds plain float tensors.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
model.eval()

# Resize to the network's 224x224 input and normalize with ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(image):
    """Classify an eye image and return per-class confidences.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``{class_label: probability}`` dict for ``gr.Label``, where the
        probabilities are the softmax of the model's logits. Returns ``None``
        (which leaves the Label output empty) when no image was provided.
    """
    if image is None:
        return None
    # Normalize() expects exactly 3 channels, so coerce grayscale / RGBA /
    # palette images to RGB before preprocessing.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Report the real model confidence instead of a fixed 1.0 for the argmax class.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return {class_labels[i]: p for i, p in enumerate(probs)}


# Gradio interface
# NOTE(review): theme="huggingface" and allow_flagging are legacy Gradio
# arguments; newer Gradio releases may warn or expect flagging_mode -- confirm
# against the pinned Gradio version.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload an Eye Image"),
    outputs=gr.Label(label="Cataract Detection Result"),
    title="👁️ Cataract Detection System 🏥",
    description="🔬 Upload an eye image, and the AI model will determine if cataract is present! 👓",
    theme="huggingface",
    allow_flagging="never",
)

demo.launch()