"""Gradio demo: classify an image with ResNet-50 and say whether it is a bird."""

import glob
import os
import random
import time

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import models, transforms

IMAGENET_CLASSES_URL = (
    "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
)
IMAGENET_CLASSES_PATH = "imagenet_classes.txt"
DOWNLOAD_TIMEOUT_SECONDS = 30

# ImageNet-1k class indices that denote birds (cock .. great grey owl,
# black grouse .. black swan, white stork .. albatross). Deciding by index
# instead of by substring of the label avoids false positives such as
# "beagle" (contains "eagle"), "mixing bowl" (contains "owl") or
# "cocker spaniel" (contains "cock"), and covers birds whose label does not
# contain a generic bird word (robin, jay, flamingo, ...).
BIRD_CLASS_INDICES = frozenset(
    [*range(7, 25), *range(80, 101), *range(127, 147)]
)


def _load_model():
    """Return a pre-trained ResNet-50 in eval mode.

    Uses the ``weights=`` API when the installed torchvision provides it and
    falls back to the legacy ``pretrained=True`` flag otherwise.
    """
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        # IMAGENET1K_V1 is the checkpoint that ``pretrained=True`` selected.
        net = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        net = models.resnet50(pretrained=True)
    net.eval()
    return net


# Load pre-trained ResNet model once
model = _load_model()

# Define image transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def download_imagenet_classes() -> bool:
    """Download ``imagenet_classes.txt`` into the working directory.

    Returns:
        True if the file was downloaded and written, False otherwise. A
        failure is reported on stdout and never raises, so callers decide
        how to react.
    """
    try:
        response = requests.get(
            IMAGENET_CLASSES_URL, timeout=DOWNLOAD_TIMEOUT_SECONDS
        )
    except requests.RequestException as exc:
        print(f"Failed to download imagenet_classes.txt: {exc}")
        return False

    if response.status_code != 200:
        print("Failed to download imagenet_classes.txt")
        return False

    with open(IMAGENET_CLASSES_PATH, "wb") as f:
        f.write(response.content)
    print("imagenet_classes.txt downloaded successfully.")
    return True


# Check if imagenet_classes.txt exists, if not, download it
if not os.path.exists(IMAGENET_CLASSES_PATH) and not download_imagenet_classes():
    raise FileNotFoundError(
        "imagenet_classes.txt is missing and could not be downloaded"
    )

# Load class labels (one label per line, line number == class index)
with open(IMAGENET_CLASSES_PATH, "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f]


def classify_image(image) -> str:
    """Classify *image* and report whether the predicted class is a bird.

    Args:
        image: Pixel array accepted by ``PIL.Image.fromarray`` (what the
            Gradio "image" input passes in), or None when nothing was
            uploaded.

    Returns:
        A human-readable sentence naming the predicted class.
    """
    if image is None:
        # Submitting the form without an image hands us None.
        return "Please provide an image to classify."

    print("Classifying image...")

    # Preprocess the image
    img = Image.fromarray(image).convert('RGB')
    batch_t = torch.unsqueeze(transform(img), 0)

    # Make prediction
    with torch.no_grad():
        output = model(batch_t)

    # Get the predicted class
    _, predicted = torch.max(output, 1)
    class_index = predicted.item()
    classification = labels[class_index]

    # Check if the predicted class is a bird
    if class_index in BIRD_CLASS_INDICES:
        return f"This is a bird! Specifically, it looks like a {classification}."
    return f"This is not a bird. It appears to be a {classification}."


# Dynamically create the list of example images
example_files = sorted(glob.glob("examples/*.png"))
examples = [[file] for file in example_files]

# Create the Gradio interface
demo = gr.Interface(
    fn=classify_image,  # The function to run
    inputs="image",  # The input type is an image
    outputs="text",  # The output type is text
    examples=examples,  # Add example images
    title="Is this a picture of a bird?",  # Title of the app
    description="Uses the latest in machine learning LLM Diffusion models to analyzes every pixel (twice) and to determine conclusively if it is a picture of a bird",  # Description of the app
)

# Launch the app
demo.launch()