File size: 2,889 Bytes
56b0bc9
b1caa99
98632cb
6671403
 
56b0bc9
5a36ad5
b1caa99
1b20ee8
 
 
 
 
 
 
 
5a36ad5
1b20ee8
 
 
 
 
 
 
 
56b0bc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b20ee8
56b0bc9
1b20ee8
 
98632cb
568c509
1b20ee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568c509
1b20ee8
 
 
 
b1caa99
98632cb
6671403
98632cb
 
 
 
 
 
 
 
6671403
 
98632cb
 
b1caa99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import gradio as gr
import glob
import time
import random
import requests
import numpy as np

# Import necessary libraries
from torchvision import models, transforms
from PIL import Image
import torch

# Load the pre-trained ResNet-50 model once at import time so every request reuses it.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of explicit weight enums;
# ResNet50_Weights.IMAGENET1K_V1 is exactly what `pretrained=True` loads. Older
# torchvision releases predate the weights API, so fall back to the legacy flag there.
if hasattr(models, "ResNet50_Weights"):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()  # evaluation mode: batch-norm layers use their running statistics

# Preprocessing applied to every input image before it is fed to the model:
# resize the short side to 256, center-crop to 224x224, convert to a tensor,
# and normalize each RGB channel with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to download imagenet_classes.txt
def download_imagenet_classes(
    url="https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    dest="imagenet_classes.txt",
    timeout=30,
):
    """Download the ImageNet class-label list and save it to a local file.

    Args:
        url: Location of the label file.
        dest: Local path the downloaded bytes are written to.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block forever on a stalled connection,
            which would hang app startup.

    Returns:
        True if the file was downloaded and written, False if the server
        answered with a non-200 status (nothing is written in that case).

    Raises:
        requests.RequestException: on network errors, including timeouts.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print("Failed to download imagenet_classes.txt")
        return False
    with open(dest, "wb") as f:
        f.write(response.content)
    print("imagenet_classes.txt downloaded successfully.")
    return True

# Check if imagenet_classes.txt exists, if not, download it
if not os.path.exists("imagenet_classes.txt"):
    download_imagenet_classes()

# The download can fail (e.g. a non-200 response only prints a message), so
# fail with an actionable error here rather than a bare error from open() below.
if not os.path.exists("imagenet_classes.txt"):
    raise FileNotFoundError(
        "imagenet_classes.txt is missing and could not be downloaded; "
        "place the ImageNet class-label file in the current working directory."
    )

# Load class labels: one per line, where line i names model output index i
# (classify_image looks up the predicted class index in this list).
with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
    labels = [line.strip() for line in f]

# ImageNet-1k class indices that are birds: 7-24 (cock ... great grey owl),
# 80-100 (black grouse ... black swan) and 127-146 (white stork ... albatross).
# This assumes `labels` follows the standard ImageNet-1k class order, as in
# pytorch/hub's imagenet_classes.txt. Checking indices instead of substring-matching
# the label text avoids false positives such as "beagle" (contains "eagle"),
# "birdhouse", "mixing bowl" (contains "owl"), "cockroach" and "hen-of-the-woods".
# It also recognizes birds whose names contain no keyword, e.g. "goldfinch",
# "toucan" or "flamingo".
BIRD_CLASS_INDICES = frozenset([*range(7, 25), *range(80, 101), *range(127, 147)])


def is_bird_class(class_index: int) -> bool:
    """Return True if ``class_index`` is one of the ImageNet-1k bird classes."""
    return class_index in BIRD_CLASS_INDICES


def classify_image(image) -> str:
    """Classify an image with ResNet-50 and report whether it shows a bird.

    Uses the module-level ``model``, ``transform`` and ``labels``.

    Args:
        image: Image data accepted by ``PIL.Image.fromarray`` (presumably the
            numpy array produced by Gradio's image input — confirm), or None
            when the user submits without an image.

    Returns:
        A human-readable sentence naming the predicted ImageNet class and
        whether it is a bird.
    """
    # Gradio calls the function with None when no image was provided.
    if image is None:
        return "Please upload an image."

    print("Classifying image...")

    # Preprocess: force 3-channel RGB, apply the model transform, add a batch dim.
    img = Image.fromarray(image).convert('RGB')
    batch_t = transform(img).unsqueeze(0)

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        output = model(batch_t)

    class_index = int(output.argmax(dim=1).item())
    classification = labels[class_index]

    if is_bird_class(class_index):
        return f"This is a bird! Specifically, it looks like a {classification}."
    return f"This is not a bird. It appears to be a {classification}."

# Build the example gallery from every PNG in ./examples, sorted for a stable
# order. Each example is a one-element list because the interface has one input.
example_files = sorted(glob.glob("examples/*.png"))
examples = [[file] for file in example_files]

# Create the Gradio interface.
# NOTE(review): the description is tongue-in-cheek; the actual classifier is a
# torchvision ResNet-50 image model, not an LLM or diffusion model.
demo = gr.Interface(
    fn=classify_image,  # The function to run
    inputs="image",  # The input type is an image
    outputs="text",  # The output type is text
    examples=examples,  # Add example images
    title="Is this a picture of a bird?",  # Title of the app
    description=(  # Description of the app
        "Uses the latest in machine learning LLM Diffusion models to analyze "
        "every pixel (twice) and to determine conclusively if it is a picture of a bird"
    ),
)
# Launch the app
demo.launch()