hassonofer commited on
Commit
ebe495d
·
1 Parent(s): 8e3318c

Update files

Browse files
Files changed (2) hide show
  1. app.py +29 -4
  2. inat21-mapping.json +0 -0
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import birder
2
  import numpy as np
3
  from birder.inference.classification import infer_image
@@ -12,6 +14,19 @@ def get_birder_classification_models():
12
  return [model.modelId.split("/")[-1] for model in models]
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def load_model_and_predict(image, model_name):
16
  try:
17
  if len(birder.list_pretrained_models(model_name)) == 0:
@@ -22,11 +37,21 @@ def load_model_and_predict(image, model_name):
22
  transform = birder.classification_transform(size, rgb_stats)
23
  (out, _) = infer_image(net, image, transform)
24
 
 
 
 
 
 
 
 
 
 
25
  idx_to_class = {v: k for k, v in class_to_idx.items()}
26
  topk_idx = np.argsort(out[0])[-3:][::-1]
27
  predictions = [(idx_to_class[idx], float(out[0][idx])) for idx in topk_idx]
28
 
29
  return predictions
 
30
  except Exception as e:
31
  return [(f"Error: {str(e)}", 0.0)]
32
 
@@ -37,18 +62,18 @@ def predict(image, model_name):
37
 
38
 
39
  def create_interface():
40
- models = get_birder_classification_models()
41
 
42
  examples = [
43
  ["Common myna.jpeg", "mvit_v2_t_il-all"],
44
- ["Eurasian hoopoe.jpeg", "convnext_v2_tiny_intermediate-eu-common"],
45
- # ["Eurasian teal.jpeg", "iformer_s_arabian-peninsula"],
46
- ["Grey heron.jpeg", "iformer_s_arabian-peninsula"],
47
  ]
48
 
49
  # Create interface
50
  iface = gr.Interface(
51
  analytics_enabled=False,
 
52
  fn=predict,
53
  inputs=[
54
  gr.Image(type="pil", label="Input Image"),
 
1
+ import json
2
+
3
  import birder
4
  import numpy as np
5
  from birder.inference.classification import infer_image
 
14
  return [model.modelId.split("/")[-1] for model in models]
15
 
16
 
17
+ def get_selected_models():
18
+ return [
19
+ "convnext_v2_tiny_intermediate-il-common",
20
+ "mvit_v2_t_il-all",
21
+ "regnet_y_8g_intermediate-eu-common",
22
+ "hiera_abswin_base_mim-intermediate-eu-common",
23
+ "focalnet_b_lrf_intermediate-arabian-peninsula",
24
+ "swin_transformer_v2_s_intermediate-arabian-peninsula",
25
+ "rope_vit_reg4_b14_capi-inat21",
26
+ "rope_vit_reg4_b14_capi-places365",
27
+ ]
28
+
29
+
30
  def load_model_and_predict(image, model_name):
31
  try:
32
  if len(birder.list_pretrained_models(model_name)) == 0:
 
37
  transform = birder.classification_transform(size, rgb_stats)
38
  (out, _) = infer_image(net, image, transform)
39
 
40
+ if model_name.endswith("-inat21") is True:
41
+ with open("inat21-mapping.json", "r", encoding="utf-8") as handle:
42
+ class_dict = json.load(handle)
43
+
44
+ class_mapping = {k: int(v) for v, k in class_dict.items()}
45
+ idx_to_class = dict(zip(class_to_idx.values(), class_to_idx.keys()))
46
+ idx_to_class.update(dict(zip(class_mapping.values(), class_mapping.keys())))
47
+ class_to_idx = dict(zip(idx_to_class.values(), idx_to_class.keys()))
48
+
49
  idx_to_class = {v: k for k, v in class_to_idx.items()}
50
  topk_idx = np.argsort(out[0])[-3:][::-1]
51
  predictions = [(idx_to_class[idx], float(out[0][idx])) for idx in topk_idx]
52
 
53
  return predictions
54
+
55
  except Exception as e:
56
  return [(f"Error: {str(e)}", 0.0)]
57
 
 
62
 
63
 
64
  def create_interface():
65
+ models = get_selected_models()
66
 
67
  examples = [
68
  ["Common myna.jpeg", "mvit_v2_t_il-all"],
69
+ ["Eurasian hoopoe.jpeg", "hiera_abswin_base_mim-intermediate-eu-common"],
70
+ ["Grey heron.jpeg", "swin_transformer_v2_s_intermediate-arabian-peninsula"],
 
71
  ]
72
 
73
  # Create interface
74
  iface = gr.Interface(
75
  analytics_enabled=False,
76
+ deep_link=False,
77
  fn=predict,
78
  inputs=[
79
  gr.Image(type="pil", label="Input Image"),
inat21-mapping.json ADDED
The diff for this file is too large to render. See raw diff