# Last commit: ebe495d ("Update files") by hassonofer
import functools
import json
import logging

import birder
import gradio as gr
import numpy as np
from birder.inference.classification import infer_image
from huggingface_hub import HfApi
def get_birder_classification_models(author="birder-project", tags="image-classification"):
    """List the short names of models published on the Hugging Face Hub by *author*.

    Queries the Hub (requires network access) through ``HfApi.list_models`` and strips
    the ``"<namespace>/"`` prefix from every repository id.

    Args:
        author: Hub user or organization whose models are listed.
        tags: Tag (or list of tags) the listed models must carry.

    Returns:
        A list of model names without their namespace prefix, in the order the Hub
        returns them.
    """
    api = HfApi()
    models = api.list_models(author=author, tags=tags)
    # `ModelInfo.id` is the documented "namespace/name" repository id; the camelCase
    # `modelId` attribute is a legacy alias and should not be relied upon.
    return [model.id.rpartition("/")[2] for model in models]
def get_selected_models():
    """Return the curated list of model names offered in the UI dropdown.

    A fresh list is built on every call, so callers may modify the result freely.
    """
    curated = (
        "convnext_v2_tiny_intermediate-il-common",
        "mvit_v2_t_il-all",
        "regnet_y_8g_intermediate-eu-common",
        "hiera_abswin_base_mim-intermediate-eu-common",
        "focalnet_b_lrf_intermediate-arabian-peninsula",
        "swin_transformer_v2_s_intermediate-arabian-peninsula",
        "rope_vit_reg4_b14_capi-inat21",
        "rope_vit_reg4_b14_capi-places365",
    )
    return list(curated)
@functools.lru_cache(maxsize=1)
def _load_inat21_names():
    """Read ``inat21-mapping.json`` once and return it as ``{class_index: name}``.

    The JSON keys are converted with ``int()``, so they must be integer strings.
    """
    with open("inat21-mapping.json", "r", encoding="utf-8") as handle:
        class_dict = json.load(handle)
    return {int(idx): name for idx, name in class_dict.items()}


# Bounded so that cycling through the dropdown does not keep every network in memory;
# 4 is a tunable trade-off between reload latency and RAM use.
@functools.lru_cache(maxsize=4)
def _load_model(model_name):
    """Resolve *model_name*, load the pretrained network and build its inference helpers.

    If ``birder.list_pretrained_models(model_name)`` finds nothing, the first match of
    the pattern ``model_name + "*"`` is loaded instead.

    Returns:
        ``(net, transform, idx_to_class)``. The result is cached and shared between
        calls, so callers must treat ``idx_to_class`` as read-only.

    Raises:
        ValueError: if no pretrained model matches *model_name*.
    """
    if len(birder.list_pretrained_models(model_name)) == 0:
        candidates = birder.list_pretrained_models(model_name + "*")
        if not candidates:
            raise ValueError(f"No pretrained model matches '{model_name}'")
        model_name = candidates[0]

    (net, (class_to_idx, signature, rgb_stats, *_)) = birder.load_pretrained_model(model_name, inference=True)
    size = birder.get_size_from_signature(signature)
    transform = birder.classification_transform(size, rgb_stats)

    # Build index -> name directly. Inverting the mapping twice would silently drop
    # indices whose names collide and later raise KeyError.
    idx_to_class = {idx: class_name for class_name, idx in class_to_idx.items()}
    if model_name.endswith("-inat21"):
        # Override the model's class names with those from the mapping file; indices
        # absent from the file keep the model's own label.
        idx_to_class.update(_load_inat21_names())

    return (net, transform, idx_to_class)


def load_model_and_predict(image, model_name, top_k=3):
    """Classify *image* with the Birder model *model_name* and return the best classes.

    Args:
        image: Image handed to ``infer_image`` (the UI supplies a PIL image).
        model_name: Pretrained model name, or a prefix resolved via ``model_name + "*"``.
        top_k: Number of highest-scoring classes to return (default 3).

    Returns:
        A list of ``(class_name, score)`` tuples ordered by descending score. On any
        failure, a single ``("Error: <message>", 0.0)`` tuple is returned instead, so the
        UI always has something to display. The traceback is logged.
    """
    try:
        (net, transform, idx_to_class) = _load_model(model_name)
        (out, _) = infer_image(net, image, transform)
        scores = out[0]
        # Ascending argsort, reversed -> indices of the highest scores first.
        topk_idx = np.argsort(scores)[::-1][:top_k]
        return [(idx_to_class[int(idx)], float(scores[idx])) for idx in topk_idx]
    except Exception as e:  # UI boundary: report the failure instead of crashing the callback
        logging.getLogger(__name__).exception("Prediction with model %r failed", model_name)
        return [(f"Error: {str(e)}", 0.0)]
def predict(image, model_name):
    """Gradio callback: classify *image* with *model_name* and shape the result for ``gr.Label``.

    Each key embeds the class name together with its confidence formatted as a
    percentage; each value is the raw confidence score.
    """
    labelled = {}
    for class_name, conf in load_model_and_predict(image, model_name):
        labelled[f"{class_name} ({conf:.2%})"] = conf
    return labelled
def create_interface():
    """Build the Gradio ``Interface`` for the Birder classification demo.

    The model dropdown lists the names from ``get_selected_models`` and defaults to the
    first one, or to ``None`` when that list is empty. The output label shows the top 3
    classes.

    Returns:
        The configured ``gr.Interface``; it is not launched here.
    """
    available = get_selected_models()
    default_model = None
    if available:
        default_model = available[0]

    image_input = gr.Image(type="pil", label="Input Image")
    model_picker = gr.Dropdown(
        choices=available,
        label="Select Model",
        value=default_model,
    )
    # Each example pairs an image file with the model used to classify it.
    sample_inputs = [
        ["Common myna.jpeg", "mvit_v2_t_il-all"],
        ["Eurasian hoopoe.jpeg", "hiera_abswin_base_mim-intermediate-eu-common"],
        ["Grey heron.jpeg", "swin_transformer_v2_s_intermediate-arabian-peninsula"],
    ]

    return gr.Interface(
        fn=predict,
        inputs=[image_input, model_picker],
        outputs=gr.Label(num_top_classes=3),
        examples=sample_inputs,
        title="Birder Image Classification",
        description="Select a model and upload an image or use one of the examples to get bird species predictions.",
        analytics_enabled=False,
        deep_link=False,
    )
if __name__ == "__main__":
    # Script entry point: build the demo UI, then start serving it.
    demo = create_interface()
    demo.launch()