# IsItABird / app.py
# Last commit by A19grey (5a36ad5): "add numpy import and update reqs"
import os
import gradio as gr
import glob
import time
import random
import requests
import numpy as np
# Import necessary libraries
from torchvision import models, transforms
from PIL import Image
import torch
# Build the ImageNet-pretrained ResNet-50 a single time, at import, and put it
# straight into evaluation (inference) mode; nn.Module.eval() returns the module
# itself, so the chained call leaves `model` bound to the same object.
# NOTE(review): newer torchvision releases deprecate ``pretrained=True`` in
# favour of ``weights=models.ResNet50_Weights.IMAGENET1K_V1``; kept as-is since
# the pinned torchvision version is not visible here.
model = models.resnet50(pretrained=True).eval()
# Per-channel (R, G, B) normalization statistics used by ImageNet models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every input picture before it reaches the model:
# scale to 256, take the central 224x224 crop, convert to a tensor, and
# normalize with the ImageNet statistics above.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Function to download imagenet_classes.txt
def download_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    path="imagenet_classes.txt",
    timeout=30,
):
    """Download the ImageNet-1k class-label list and save it to *path*.

    The payload is first written to ``path + ".part"`` and then moved into
    place with ``os.replace``. An interrupted write therefore never leaves a
    truncated label file behind, which the "does the file exist?" check on the
    next start-up would otherwise accept as valid.

    Args:
        url: Location of the label list.
        path: Local destination file.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` may block indefinitely and hang app start-up.

    Returns:
        True if the file was downloaded and saved, False otherwise. Failures
        are reported with ``print``, matching the original best-effort design.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as exc:
        # Connection errors / timeouts are reported like HTTP failures instead
        # of escaping as an unrelated traceback.
        print(f"Failed to download imagenet_classes.txt: {exc}")
        return False
    if response.status_code != 200:
        print("Failed to download imagenet_classes.txt")
        return False
    tmp_path = path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, path)
    print("imagenet_classes.txt downloaded successfully.")
    return True
# Fetch the label file on first run only; later runs reuse the local copy.
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()

# Load class labels: line i of the file is the name of model output index i
# (classify_image looks labels up by the predicted index).
try:
    with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f]
except FileNotFoundError as err:
    # Happens when the download above failed; give an actionable message
    # instead of a bare "No such file" traceback.
    raise FileNotFoundError(
        "imagenet_classes.txt is missing and could not be downloaded; "
        "place the ImageNet-1k class-label file next to app.py"
    ) from err
# ImageNet-1k output indices whose classes are birds: 7-24 (cock ... great
# grey owl), 80-100 (black grouse ... black swan) and 127-146 (white stork ...
# albatross). Testing the index instead of substrings of the label text avoids
# false positives such as "beagle" (contains "eagle"), "mixing bowl" (contains
# "owl"), "cockroach" (contains "cock") or "hen-of-the-woods" (a mushroom).
# It also recognises birds such as "goldfinch" or "flamingo", whose names
# contain none of the old keywords.
BIRD_CLASS_INDICES = frozenset((*range(7, 25), *range(80, 101), *range(127, 147)))


def classify_image(image):
    """Classify *image* with the ResNet-50 model and say whether it is a bird.

    Args:
        image: The picture supplied by the Gradio ``"image"`` input, or
            ``None`` when the user submits without uploading anything. It is
            passed to ``PIL.Image.fromarray``, so presumably an ``H x W x C``
            NumPy array -- TODO confirm against the Gradio version in use.

    Returns:
        A sentence naming the top-1 ImageNet class and whether it is a bird,
        or a prompt to upload an image when *image* is ``None``.
    """
    if image is None:
        # Gradio calls the function with None when no image was provided;
        # Image.fromarray(None) would raise instead of giving a useful message.
        return "Please upload an image to classify."
    print("Classifying image...")
    # Force 3 channels (drops alpha / expands grayscale), apply the ImageNet
    # preprocessing, and add a batch dimension of size 1.
    img = Image.fromarray(image).convert("RGB")
    batch_t = torch.unsqueeze(transform(img), 0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(batch_t)
    # Top-1 class index for the single image in the batch.
    class_index = int(output.argmax(dim=1).item())
    classification = labels[class_index]
    if class_index in BIRD_CLASS_INDICES:
        return f"This is a bird! Specifically, it looks like a {classification}."
    return f"This is not a bird. It appears to be a {classification}."
# Collect every PNG under examples/ in a stable (sorted) order; Gradio expects
# one row per example, each row holding a single value for the one input.
example_files = glob.glob("examples/*.png")
example_files.sort()
examples = [[path] for path in example_files]
# Create the Gradio interface: one image in, one sentence of text out.
demo = gr.Interface(
    fn=classify_image,  # The function to run
    inputs="image",  # The input type is an image
    outputs="text",  # The output type is text
    examples=examples,  # Add example images
    title="Is this a picture of a bird?",  # Title of the app
    # NOTE(review): the description is presumably tongue-in-cheek -- the model
    # actually used is the ImageNet ResNet-50 classifier (models.resnet50).
    description=(
        "Uses the latest in machine learning LLM Diffusion models to analyze "
        "every pixel (twice) and to determine conclusively if it is a picture "
        "of a bird"
    ),  # Description of the app
)
# Launch the app
demo.launch()