"""Gradio demo that classifies sports images with a fine-tuned ResNet-18.

The checkpoint and the class-label list are downloaded from the Hugging Face
Hub repo ``AventIQ-AI/resnet18-sports-category-classification`` at import
time, and the web UI is launched at the bottom of this module.
"""

import json

import gradio as gr
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download
from PIL import Image

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"

REPO_ID = "AventIQ-AI/resnet18-sports-category-classification"

# Load model weights and class labels
weights_path = hf_hub_download(repo_id=REPO_ID, filename="resnet18_sports_classification.pth")
labels_path = hf_hub_download(repo_id=REPO_ID, filename="class_labels.json")

with open(labels_path, "r", encoding="utf-8") as f:
    # NOTE(review): predict() indexes this by integer class id, so it is
    # assumed to be a JSON list ordered like the model's output units.
    class_labels = json.load(f)

# Load model: build an untrained ResNet-18 skeleton (``weights=None`` is the
# non-deprecated spelling of ``pretrained=False``) and replace the final
# fully-connected layer so it emits one logit per class label.
model = models.resnet18(weights=None)
num_classes = len(class_labels)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint fetched from the network cannot run arbitrary code.
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
model.to(device)
model.eval()

# Define transform: resize to 224x224, convert to a tensor, then normalize
# with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Prediction function
def predict(image):
    """Classify an uploaded image and return per-class probabilities.

    Args:
        image: PIL image from the Gradio upload widget, or ``None`` when the
            user has cleared the input.

    Returns:
        A ``{label: probability}`` dict covering every class (softmax of the
        model's logits), as consumed by ``gr.Label``; ``None`` when no image
        was supplied, which clears the output widget.
    """
    if image is None:
        return None
    # Normalize() above expects exactly 3 channels; uploads may be RGBA, L or P.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.softmax(logits, dim=1)[0].cpu().tolist()
    return {str(class_labels[i]): p for i, p in enumerate(probabilities)}


# Gradio Interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a Sports Image"),
    # Show only the most likely classes rather than the full distribution.
    outputs=gr.Label(num_top_classes=5, label="Predicted Sport Category"),
    title="🏆 Sports Category Classifier 🏅",
    description="🎯 Upload an image of a sports activity, and the AI model will classify it into the correct category! 🏀⚽🏓",
    # NOTE(review): `theme="huggingface"` and `allow_flagging` are legacy Gradio
    # options; confirm they are still accepted by the pinned Gradio version.
    theme="huggingface",
    allow_flagging="never",
)

demo.launch()