File size: 3,295 Bytes
ebe495d
 
3757755
 
 
 
 
 
 
 
 
 
 
 
 
 
ebe495d
 
 
 
 
 
 
 
 
 
 
 
 
3757755
 
67162f7
 
 
2d91bdc
3757755
 
 
 
ebe495d
 
 
 
 
 
 
 
 
3757755
 
 
 
 
ebe495d
3757755
 
 
 
 
 
 
 
 
 
ebe495d
3757755
02be6f1
 
ebe495d
 
3757755
 
 
 
 
ebe495d
3757755
 
 
 
 
 
 
 
 
 
02be6f1
3757755
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import json

import birder
import numpy as np
from birder.inference.classification import infer_image
from huggingface_hub import HfApi

import gradio as gr


def get_birder_classification_models():
    """Return the short names of image-classification models published by ``birder-project``.

    Queries the Hugging Face Hub for repositories authored by ``birder-project``
    tagged ``image-classification`` and keeps only the part of each repository id
    after the last ``/``.
    """
    hub = HfApi()
    names = []
    for info in hub.list_models(author="birder-project", tags="image-classification"):
        # "owner/name" -> "name" (an id without "/" is kept whole).
        names.append(info.modelId.rpartition("/")[2])
    return names


def get_selected_models():
    """Return the curated model names offered in the UI's model dropdown.

    A new list is built on every call, so callers may mutate the result freely.
    """
    curated = (
        "convnext_v2_tiny_intermediate-il-common",
        "mvit_v2_t_il-all",
        "regnet_y_8g_intermediate-eu-common",
        "hiera_abswin_base_mim-intermediate-eu-common",
        "focalnet_b_lrf_intermediate-arabian-peninsula",
        "swin_transformer_v2_s_intermediate-arabian-peninsula",
        "rope_vit_reg4_b14_capi-inat21",
        "rope_vit_reg4_b14_capi-places365",
    )
    return list(curated)


def load_model_and_predict(image, model_name):
    """Run a Birder pretrained classifier on ``image`` and return its top-3 predictions.

    Args:
        image: The image to classify, as delivered by the UI's ``gr.Image(type="pil")``
            input. ``None`` (nothing submitted) yields an error entry without
            loading any model.
        model_name: Birder pretrained-model name. When no model matches it exactly,
            the first result for the wildcard pattern ``model_name + "*"`` is used.

    Returns:
        A list of ``(class_name, score)`` tuples ordered from highest to lowest score.
        Any exception is converted into a single ``("Error: <message>", 0.0)`` entry
        so the UI can display it instead of failing.
    """
    if image is None:
        return [("Error: no image provided", 0.0)]

    try:
        if not birder.list_pretrained_models(model_name):
            # No exact match: fall back to the first result for the wildcard pattern.
            model_name = birder.list_pretrained_models(model_name + "*")[0]

        (net, (class_to_idx, signature, rgb_stats, *_)) = birder.load_pretrained_model(model_name, inference=True)
        size = birder.get_size_from_signature(signature)
        transform = birder.classification_transform(size, rgb_stats)
        (out, _) = infer_image(net, image, transform)

        idx_to_class = {idx: name for name, idx in class_to_idx.items()}
        if model_name.endswith("-inat21"):
            # Override class names with those from the local mapping file, whose keys
            # are converted to int class indices and whose values are display names.
            # Updating the index->name dict directly (rather than round-tripping through
            # a name->index dict) keeps every index even when two classes share a name.
            with open("inat21-mapping.json", "r", encoding="utf-8") as handle:
                class_dict = json.load(handle)
            idx_to_class.update({int(idx): name for idx, name in class_dict.items()})

        scores = out[0]
        # argsort is ascending: take the last three indices and reverse for best-first.
        topk_idx = np.argsort(scores)[-3:][::-1]
        return [(idx_to_class[idx], float(scores[idx])) for idx in topk_idx]

    except Exception as e:
        return [(f"Error: {e}", 0.0)]


def predict(image, model_name):
    """Classify ``image`` and format the results for a ``gr.Label`` output.

    Returns a dict whose keys are ``"<class name> (<score as percent>)"`` and whose
    values are the raw scores from ``load_model_and_predict``.
    """
    labelled = {}
    for label, score in load_model_and_predict(image, model_name):
        labelled[f"{label} ({score:.2%})"] = score
    return labelled


def create_interface():
    """Build and return the Gradio ``Interface`` for the classification demo.

    The UI takes an image and a model chosen from ``get_selected_models()``
    (defaulting to the first one, if any), runs ``predict``, and shows the top-3
    labels. Three bundled example images are offered.
    """
    choices = get_selected_models()
    default_choice = choices[0] if choices else None

    # Components are created in the same order as the original: inputs, then output.
    image_input = gr.Image(type="pil", label="Input Image")
    model_input = gr.Dropdown(
        choices=choices,
        label="Select Model",
        value=default_choice,
    )
    label_output = gr.Label(num_top_classes=3)

    example_rows = [
        ["Common myna.jpeg", "mvit_v2_t_il-all"],
        ["Eurasian hoopoe.jpeg", "hiera_abswin_base_mim-intermediate-eu-common"],
        ["Grey heron.jpeg", "swin_transformer_v2_s_intermediate-arabian-peninsula"],
    ]

    return gr.Interface(
        fn=predict,
        inputs=[image_input, model_input],
        outputs=label_output,
        examples=example_rows,
        title="Birder Image Classification",
        description="Select a model and upload an image or use one of the examples to get bird species predictions.",
        analytics_enabled=False,
        deep_link=False,
    )


# Launch the app: build the Gradio interface and start serving it, but only when
# this file is executed directly (importing the module does not launch anything).
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()