File size: 2,889 Bytes
56b0bc9 b1caa99 98632cb 6671403 56b0bc9 5a36ad5 b1caa99 1b20ee8 5a36ad5 1b20ee8 56b0bc9 1b20ee8 56b0bc9 1b20ee8 98632cb 568c509 1b20ee8 568c509 1b20ee8 b1caa99 98632cb 6671403 98632cb 6671403 98632cb b1caa99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import os
import gradio as gr
import glob
import time
import random
import requests
import numpy as np
# Import necessary libraries
from torchvision import models, transforms
from PIL import Image
import torch
# Load the pre-trained ResNet-50 once at import time and put it in inference
# mode (eval() makes layers such as batch-norm use their stored statistics).
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum; IMAGENET1K_V1 is the checkpoint `pretrained=True` loads, so predictions
# are unchanged. Older torchvision has no enum, so fall back to the legacy flag.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    model = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()

# ImageNet preprocessing: resize the shorter side to 256, center-crop 224x224,
# convert to a float tensor, then normalize each channel with the ImageNet
# mean/std the weights were trained with.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def download_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    dest="imagenet_classes.txt",
    timeout=30,
):
    """Download the ImageNet class-label list to *dest*.

    Best-effort: on any failure a message is printed and *dest* is left
    untouched, so the caller decides how to proceed.

    Args:
        url: Where to fetch the label file from.
        dest: Local path to write the file to.
        timeout: Seconds to wait on the server before giving up.

    Returns:
        True if the file was downloaded and written, False otherwise.
    """
    try:
        # Without a timeout, requests can block forever on a stalled connection.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Failed to download imagenet_classes.txt ({err})")
        return False
    if response.status_code != 200:
        print("Failed to download imagenet_classes.txt")
        return False
    # Write to a temporary file and rename it into place, so an interrupted
    # write never leaves a truncated file that a later exists() check accepts.
    tmp_path = f"{dest}.tmp"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, dest)
    print("imagenet_classes.txt downloaded successfully.")
    return True
# Fetch the label list on first run, then load it as one class name per line.
# classify_image indexes this list with the model's predicted class index, so
# line order must match the model's output order.
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()
try:
    with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f]
except FileNotFoundError as err:
    # The download helper only prints on failure; surface the real cause here
    # instead of a bare "No such file" error.
    raise FileNotFoundError(
        "imagenet_classes.txt is missing and could not be downloaded"
    ) from err
# ImageNet-1k class indices that are birds: 7-24 (cock ... great grey owl),
# 80-100 (black grouse ... black swan) and 127-146 (white stork ... albatross).
# Matching on indices is exact. Matching keywords inside label text would
# misfire on e.g. "beagle" ("eagle"), "mixing bowl" ("owl"), "cockroach"
# ("cock") and "birdhouse" ("bird"), and would miss most real birds
# (robin, ostrich, flamingo, pelican, ...).
BIRD_CLASS_INDICES = frozenset(
    [*range(7, 25), *range(80, 101), *range(127, 147)]
)


def classify_image(image):
    """Classify *image* with the ResNet-50 model and say whether it is a bird.

    Args:
        image: An array accepted by ``PIL.Image.fromarray`` (the Gradio image
            input supplies a NumPy array), or None when the user submits
            without uploading anything.

    Returns:
        A sentence naming the top-1 ImageNet class and stating whether that
        class is a bird, or a prompt to upload an image when *image* is None.
    """
    if image is None:
        # Gradio passes None when no image was provided; Image.fromarray(None)
        # would raise.
        return "Please upload an image first."
    print("Classifying image...")
    # Preprocess: RGB PIL image -> normalized tensor -> batch of one.
    img = Image.fromarray(image).convert('RGB')
    batch_t = torch.unsqueeze(transform(img), 0)
    with torch.no_grad():
        output = model(batch_t)
    # Top-1 class index for the single image in the batch.
    predicted_index = int(output.argmax(dim=1).item())
    classification = labels[predicted_index]
    if predicted_index in BIRD_CLASS_INDICES:
        return f"This is a bird! Specifically, it looks like a {classification}."
    return f"This is not a bird. It appears to be a {classification}."
# Collect the demo's example images: every PNG directly under examples/, in
# sorted order, each wrapped in a one-item list (one value per interface input).
example_files = glob.glob("examples/*.png")
example_files.sort()
examples = [[png_path] for png_path in example_files]
# Build the Gradio UI: one image in, one line of text out, plus any example
# images collected above.
demo = gr.Interface(
    fn=classify_image,
    # type="numpy" is spelled out because classify_image hands the value
    # straight to PIL's Image.fromarray.
    inputs=gr.Image(type="numpy"),
    outputs="text",
    examples=examples,
    title="Is this a picture of a bird?",
    description=(
        "Uses the latest in machine learning LLM Diffusion models to analyze "
        "every pixel (twice) and determine conclusively if it is a picture of a bird"
    ),
)
# Start the web app.
demo.launch()