A19grey commited on
Commit
1b20ee8
·
1 Parent(s): a149fdb

Optimized ResNet50 model loading

Browse files
Files changed (1) hide show
  1. app.py +44 -3
app.py CHANGED
@@ -3,12 +3,53 @@ import glob
3
  import time
4
  import random
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def classify_image(image):
7
  # Wait for a random interval between 0.5 and 1.5 seconds to look useful
8
- time.sleep(random.uniform(0.5, 1.5))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # This function just returns "Not a bird" for any input image
11
- return "Not a bird"
 
 
12
 
13
  # Dynamically create the list of example images
14
  example_files = sorted(glob.glob("examples/*.png"))
 
3
  import time
4
  import random
5
 
6
+ # Import necessary libraries
7
+ from torchvision import models, transforms
8
+ from PIL import Image
9
+ import torch
10
+
11
+ # Load pre-trained ResNet model once
12
+ model = models.resnet50(pretrained=True)
13
+ model.eval()
14
+
15
+ # Define image transformations
16
+ transform = transforms.Compose([
17
+ transforms.Resize(256),
18
+ transforms.CenterCrop(224),
19
+ transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
21
+ ])
22
+
23
+ # Load class labels
24
+ with open('imagenet_classes.txt') as f:
25
+ labels = [line.strip() for line in f.readlines()]
26
+
27
  def classify_image(image):
28
  # Wait for a random interval between 0.5 and 1.5 seconds to look useful
29
+ # time.sleep(random.uniform(0.5, 1.5))
30
+ print("Classifying image...")
31
+
32
+ # Preprocess the image
33
+ img = Image.fromarray(image).convert('RGB')
34
+ img_t = transform(img)
35
+ batch_t = torch.unsqueeze(img_t, 0)
36
+
37
+ # Make prediction
38
+ with torch.no_grad():
39
+ output = model(batch_t)
40
+
41
+ # Get the predicted class
42
+ _, predicted = torch.max(output, 1)
43
+ classification = labels[predicted.item()]
44
+
45
+ # Check if the predicted class is a bird
46
+ bird_classes = ['bird', 'fowl', 'hen', 'cock', 'rooster', 'peacock', 'parrot', 'eagle', 'owl', 'penguin']
47
+ is_bird = any(bird_class in classification.lower() for bird_class in bird_classes)
48
 
49
+ if is_bird:
50
+ return f"This is a bird! Specifically, it looks like a {classification}."
51
+ else:
52
+ return f"This is not a bird. It appears to be a {classification}."
53
 
54
  # Dynamically create the list of example images
55
  example_files = sorted(glob.glob("examples/*.png"))