"""Gradio demo for Birder image-classification models.

The user picks one of a curated list of pretrained Birder models, uploads an
image, and sees the model's top-ranked class predictions.
"""

import functools
import json
import logging

import birder
import gradio as gr
import numpy as np
from birder.inference.classification import infer_image
from huggingface_hub import HfApi

logger = logging.getLogger(__name__)

# Number of top-ranked predictions returned and displayed.
TOP_K = 3

# How many loaded models to keep in memory. Loading weights is expensive, so
# recently used models are reused across requests; the bound keeps memory use
# in check when users switch between many models.
MODEL_CACHE_SIZE = 2

# Mapping file (relative to the working directory) used to relabel the
# classes of "-inat21" models.
INAT21_MAPPING_PATH = "inat21-mapping.json"


def get_birder_classification_models():
    """Return short names of the "image-classification" models authored by birder-project on the Hub."""
    api = HfApi()
    models = api.list_models(author="birder-project", tags="image-classification")
    # modelId is "<author>/<name>"; keep only the name part.
    return [model.modelId.split("/")[-1] for model in models]


def get_selected_models():
    """Return the curated list of model names offered in the UI dropdown."""
    return [
        "convnext_v2_tiny_intermediate-il-common",
        "mvit_v2_t_il-all",
        "regnet_y_8g_intermediate-eu-common",
        "hiera_abswin_base_mim-intermediate-eu-common",
        "focalnet_b_lrf_intermediate-arabian-peninsula",
        "swin_transformer_v2_s_intermediate-arabian-peninsula",
        "rope_vit_reg4_b14_capi-inat21",
        "rope_vit_reg4_b14_capi-places365",
    ]


def _resolve_model_name(model_name):
    """Return a pretrained model name known to Birder for ``model_name``.

    If ``model_name`` is not itself a pretrained name, the first pretrained
    model matching the pattern ``model_name + "*"`` is used instead.

    Raises:
        ValueError: if nothing matches.
    """
    if birder.list_pretrained_models(model_name):
        return model_name

    candidates = birder.list_pretrained_models(model_name + "*")
    if not candidates:
        raise ValueError(f"No pretrained model matches '{model_name}'")

    return candidates[0]


@functools.lru_cache(maxsize=MODEL_CACHE_SIZE)
def _load_model(model_name):
    """Load a pretrained model and build its preprocessing transform (cached).

    Failed loads raise and are therefore not cached, so they are retried on
    the next request.

    Returns:
        Tuple ``(resolved_name, net, class_to_idx, transform)``.
    """
    resolved_name = _resolve_model_name(model_name)
    (net, (class_to_idx, signature, rgb_stats, *_)) = birder.load_pretrained_model(resolved_name, inference=True)
    size = birder.get_size_from_signature(signature)
    transform = birder.classification_transform(size, rgb_stats)
    return (resolved_name, net, class_to_idx, transform)


@functools.lru_cache(maxsize=1)
def _load_inat21_mapping():
    """Return ``{class_index: display_name}`` read from the iNat21 mapping file.

    The JSON keys are class indices stored as strings; they are converted to int.
    """
    with open(INAT21_MAPPING_PATH, "r", encoding="utf-8") as handle:
        class_dict = json.load(handle)

    return {int(idx): name for idx, name in class_dict.items()}


def _build_idx_to_class(model_name, class_to_idx):
    """Return a fresh ``{class_index: label}`` dict for ``model_name``.

    For "-inat21" models the labels from the mapping file override the model's
    own class names. Building the dict directly, rather than round-tripping
    through a name->index dict, keeps every index even when two classes share
    a display name.
    """
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}
    if model_name.endswith("-inat21"):
        idx_to_class.update(_load_inat21_mapping())

    return idx_to_class


def load_model_and_predict(image, model_name, top_k=TOP_K):
    """Run ``model_name`` on ``image`` and return its ``top_k`` best predictions.

    Args:
        image: Input image as delivered by the Gradio ``Image(type="pil")`` input.
        model_name: Pretrained model name, or a prefix expanded via ``"*"``.
        top_k: Number of predictions to return (default ``TOP_K``).

    Returns:
        List of ``(class_name, score)`` tuples ordered by descending score. On
        any failure, a single ``("Error: <message>", 0.0)`` entry is returned
        so the UI shows the problem instead of crashing.
    """
    try:
        (resolved_name, net, class_to_idx, transform) = _load_model(model_name)
        (out, _) = infer_image(net, image, transform)
        idx_to_class = _build_idx_to_class(resolved_name, class_to_idx)

        scores = out[0]
        # Reverse first, then slice, so top_k=0 yields nothing rather than everything.
        topk_idx = np.argsort(scores)[::-1][:top_k]
        return [(idx_to_class[int(idx)], float(scores[idx])) for idx in topk_idx]
    except Exception as e:  # UI boundary: report instead of crashing the app
        logger.exception("Prediction failed for model %s", model_name)
        return [(f"Error: {str(e)}", 0.0)]


def predict(image, model_name):
    """Gradio callback: return ``{"<class> (<score%>)": score}`` for ``gr.Label``."""
    predictions = load_model_and_predict(image, model_name)
    return {f"{class_name} ({conf:.2%})": conf for class_name, conf in predictions}


def create_interface():
    """Build the Gradio ``Interface`` wiring the image/model inputs to ``predict``."""
    models = get_selected_models()
    examples = [
        ["Common myna.jpeg", "mvit_v2_t_il-all"],
        ["Eurasian hoopoe.jpeg", "hiera_abswin_base_mim-intermediate-eu-common"],
        ["Grey heron.jpeg", "swin_transformer_v2_s_intermediate-arabian-peninsula"],
    ]

    # Create interface
    iface = gr.Interface(
        analytics_enabled=False,
        deep_link=False,
        fn=predict,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Dropdown(
                choices=models,
                label="Select Model",
                value=models[0] if models else None,
            ),
        ],
        outputs=gr.Label(num_top_classes=TOP_K),
        examples=examples,
        title="Birder Image Classification",
        description="Select a model and upload an image or use one of the examples to get bird species predictions.",
    )

    return iface


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()