Spaces:
Sleeping
Sleeping
app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import pickle
|
| 4 |
+
from torchvision import transforms
|
| 5 |
+
from torchvision.transforms import ToTensor, RandomErasing
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import numpy as np
|
| 8 |
+
from ultralytics import YOLO
|
| 9 |
+
import io
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Load the YOLO model for bird detection
|
| 13 |
+
yolo_model = YOLO('yolov5su.pt')
|
| 14 |
+
|
| 15 |
+
class CPU_Unpickler(pickle.Unpickler):
|
| 16 |
+
def find_class(self, module, name):
|
| 17 |
+
if module == 'torch.storage' and name == '_load_from_bytes':
|
| 18 |
+
return lambda b: torch.load(io.BytesIO(b), map_location='cpu') # Ensure loading on CPU
|
| 19 |
+
else:
|
| 20 |
+
return super().find_class(module, name)
|
| 21 |
+
|
| 22 |
+
# Load your model using the custom unpickler
|
| 23 |
+
with open(".\model\model_resultsconvnext_large.pkl", "rb") as file:
|
| 24 |
+
model = CPU_Unpickler(file).load()
|
| 25 |
+
model = model['convnext_large']['model']
|
| 26 |
+
model.eval()
|
| 27 |
+
# Function to detect bird region
|
| 28 |
+
def detect_bird_region(image):
|
| 29 |
+
results = yolo_model(image, verbose=False)
|
| 30 |
+
bird_boxes = results[0].boxes[results[0].boxes.cls == 14]
|
| 31 |
+
if len(bird_boxes) > 0:
|
| 32 |
+
return bird_boxes[0].xyxy[0].cpu().numpy() # Coordinates of the first detected bird
|
| 33 |
+
return None
|
| 34 |
+
|
| 35 |
+
# Preprocessing function for inference
|
| 36 |
+
def preprocess_image(image):
|
| 37 |
+
bird_box = detect_bird_region(image)
|
| 38 |
+
if bird_box is not None:
|
| 39 |
+
image = image.crop(bird_box) # Crop to bird region
|
| 40 |
+
|
| 41 |
+
# Apply validation transformations
|
| 42 |
+
val_transform = transforms.Compose([
|
| 43 |
+
transforms.Resize((229, 229)),
|
| 44 |
+
ToTensor(),
|
| 45 |
+
transforms.ConvertImageDtype(torch.float32),
|
| 46 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 47 |
+
])
|
| 48 |
+
return val_transform(image).unsqueeze(0) # Add batch dimension
|
| 49 |
+
|
| 50 |
+
# Prediction function
|
| 51 |
+
def predict(image):
|
| 52 |
+
# Preprocess the image
|
| 53 |
+
image = preprocess_image(image)
|
| 54 |
+
|
| 55 |
+
# Perform prediction
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
outputs = model(image)
|
| 58 |
+
predicted_class = torch.argmax(outputs, dim=1).item()
|
| 59 |
+
|
| 60 |
+
# Map the predicted class to bird names using bird_folders
|
| 61 |
+
bird_folders = {
|
| 62 |
+
0: "019.Gray_Catbird",
|
| 63 |
+
1: "025.Pelagic_Cormorant",
|
| 64 |
+
2: "026.Bronzed_Cowbird",
|
| 65 |
+
3: "029.American_Crow",
|
| 66 |
+
4: "039.Least_Flycatcher",
|
| 67 |
+
5: "073.Blue_Jay",
|
| 68 |
+
6: "085.Horned_Lark",
|
| 69 |
+
7: "099.Ovenbird",
|
| 70 |
+
8: "104.American_Pipit",
|
| 71 |
+
9: "119.Field_Sparrow",
|
| 72 |
+
10: "127.Savannah_Sparrow",
|
| 73 |
+
11: "129.Song_Sparrow",
|
| 74 |
+
12: "135.Bank_Swallow",
|
| 75 |
+
13: "137.Cliff_Swallow",
|
| 76 |
+
14: "138.Tree_Swallow",
|
| 77 |
+
15: "142.Black_Tern",
|
| 78 |
+
16: "143.Caspian_Tern",
|
| 79 |
+
17: "144.Common_Tern",
|
| 80 |
+
18: "167.Hooded_Warbler",
|
| 81 |
+
19: "176.Prairie_Warbler",
|
| 82 |
+
20: "177.Prothonotary_Warbler",
|
| 83 |
+
21: "179.Tennessee_Warbler",
|
| 84 |
+
22: "182.Yellow_Warbler",
|
| 85 |
+
23: "183.Northern_Waterthrush",
|
| 86 |
+
24: "185.Bohemian_Waxwing",
|
| 87 |
+
25: "186.Cedar_Waxwing",
|
| 88 |
+
26: "188.Pileated_Woodpecker",
|
| 89 |
+
27: "192.Downy_Woodpecker",
|
| 90 |
+
28: "195.Carolina_Wren",
|
| 91 |
+
29: "199.Winter_Wren"
|
| 92 |
+
}
|
| 93 |
+
return bird_folders[predicted_class] # Return bird name as output
|
| 94 |
+
|
| 95 |
+
# Gradio Interface
|
| 96 |
+
interface = gr.Interface(
|
| 97 |
+
fn=predict,
|
| 98 |
+
inputs=gr.Image(type="pil"),
|
| 99 |
+
outputs="label" # Display class label as output
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Launch Gradio App
|
| 103 |
+
interface.launch()
|