|
import json |
|
|
|
import birder |
|
import numpy as np |
|
from birder.inference.classification import infer_image |
|
from huggingface_hub import HfApi |
|
|
|
import gradio as gr |
|
|
|
|
|
def get_birder_classification_models():
    """List image-classification models published by the ``birder-project`` Hub author.

    Queries the Hugging Face Hub and returns each repository id with its owner
    prefix removed (i.e. only the text after the last ``/``).
    """
    hub = HfApi()
    listing = hub.list_models(author="birder-project", tags="image-classification")

    short_names = []
    for info in listing:
        # rpartition(...)[2] is the text after the last "/" (the whole id if none).
        short_names.append(info.modelId.rpartition("/")[2])

    return short_names
|
|
|
|
|
def get_selected_models():
    """Return the curated model names offered in the UI dropdown.

    A new list is returned on every call, so callers may mutate it freely.
    """
    curated = (
        "convnext_v2_tiny_intermediate-il-common",
        "mvit_v2_t_il-all",
        "regnet_y_8g_intermediate-eu-common",
        "hiera_abswin_base_mim-intermediate-eu-common",
        "focalnet_b_lrf_intermediate-arabian-peninsula",
        "swin_transformer_v2_s_intermediate-arabian-peninsula",
        "rope_vit_reg4_b14_capi-inat21",
        "rope_vit_reg4_b14_capi-places365",
    )
    return list(curated)
|
|
|
|
|
def load_model_and_predict(image, model_name, top_k=3):
    """Run a Birder pretrained classifier on ``image`` and return its top predictions.

    Args:
        image: Input image, forwarded unchanged to ``infer_image`` (the Gradio
            input in this file is configured with ``type="pil"``).
        model_name: Pretrained model name. If no model with exactly this name is
            registered, the first match for the pattern ``model_name + "*"`` is used.
        top_k: Number of highest-scoring classes to return (default 3).

    Returns:
        A list of ``(class_name, score)`` tuples ordered from highest to lowest
        score. On any failure, a single ``("Error: <message>", 0.0)`` entry is
        returned instead of raising, so the UI can display the problem.
    """
    import logging

    try:
        if len(birder.list_pretrained_models(model_name)) == 0:
            matches = birder.list_pretrained_models(model_name + "*")
            if not matches:
                # Previously this surfaced as an opaque "list index out of range".
                raise ValueError(f"No pretrained model matching '{model_name}'")
            model_name = matches[0]

        (net, (class_to_idx, signature, rgb_stats, *_)) = birder.load_pretrained_model(model_name, inference=True)
        size = birder.get_size_from_signature(signature)
        transform = birder.classification_transform(size, rgb_stats)
        (out, _) = infer_image(net, image, transform)

        idx_to_class = {idx: name for name, idx in class_to_idx.items()}

        if model_name.endswith("-inat21"):
            with open("inat21-mapping.json", "r", encoding="utf-8") as handle:
                class_dict = json.load(handle)

            # The mapping's keys are treated as class indices (int-convertible
            # strings) and its values as display names. Update index -> name
            # directly: inverting through a name -> index dict (as before) dropped
            # any index whose name was duplicated, causing a KeyError below.
            idx_to_class.update({int(idx): name for idx, name in class_dict.items()})

        scores = out[0]
        # argsort is ascending; reverse and keep the first `top_k` indices.
        topk_idx = np.argsort(scores)[::-1][:top_k]

        return [(idx_to_class[int(idx)], float(scores[idx])) for idx in topk_idx]

    except Exception as e:  # UI boundary: report the error instead of crashing the request
        logging.getLogger(__name__).exception("Prediction failed for model %s", model_name)
        return [(f"Error: {str(e)}", 0.0)]
|
|
|
|
|
def predict(image, model_name):
    """Gradio callback: classify ``image`` with ``model_name`` and format for ``gr.Label``.

    Each key of the returned dict is the class name followed by its score as a
    percentage with two decimals (e.g. ``"name (97.31%)"``); each value is the raw score.
    """
    labelled = {}
    for class_name, conf in load_model_and_predict(image, model_name):
        labelled[f"{class_name} ({conf:.2%})"] = conf
    return labelled
|
|
|
|
|
def create_interface():
    """Build (but do not launch) the Gradio ``Interface`` for the classification demo.

    The model dropdown lists the curated models and defaults to the first one,
    or to no selection if the list is empty. Example rows pair an image file
    with a model name.
    """
    model_choices = get_selected_models()
    default_model = None
    if model_choices:
        default_model = model_choices[0]

    # NOTE(review): example image paths are relative — presumably resolved
    # against the working directory / app root; confirm they ship with the app.
    example_rows = [
        ["Common myna.jpeg", "mvit_v2_t_il-all"],
        ["Eurasian hoopoe.jpeg", "hiera_abswin_base_mim-intermediate-eu-common"],
        ["Grey heron.jpeg", "swin_transformer_v2_s_intermediate-arabian-peninsula"],
    ]

    image_input = gr.Image(type="pil", label="Input Image")
    model_input = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
    label_output = gr.Label(num_top_classes=3)

    return gr.Interface(
        fn=predict,
        inputs=[image_input, model_input],
        outputs=label_output,
        examples=example_rows,
        title="Birder Image Classification",
        description="Select a model and upload an image or use one of the examples to get bird species predictions.",
        analytics_enabled=False,
        deep_link=False,
    )
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the demo UI and start the Gradio server.
    # The interface is bound to the module-level name `demo` before launching.
    demo = create_interface()
    demo.launch()
|
|