davanstrien HF Staff commited on
Commit
de33a84
·
1 Parent(s): d5321e7

check valid hub id

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -10,8 +10,7 @@ import pandas as pd
10
  import os
11
  import backoff
12
  from functools import lru_cache
13
-
14
- import os
15
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
16
 
17
 
@@ -65,15 +64,20 @@ def return_random_sample(k=27):
65
  images = dataset[sample]["image"]
66
  return [resize_image(image).convert("RGB") for image in images]
67
 
 
 
 
 
68
 
69
  def predict_subset(model_id, token):
70
  API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
71
  headers = {"Authorization": f"Bearer {token}"}
72
-
 
 
73
  @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
74
  def _query(url):
75
  r = requests.post(API_URL, headers=headers, data=url)
76
- print(r)
77
  return r
78
 
79
  @lru_cache(maxsize=1000)
 
10
  import os
11
  import backoff
12
  from functools import lru_cache
13
+ from huggingface_hub import list_models, ModelFilter
 
14
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
15
 
16
 
 
64
  images = dataset[sample]["image"]
65
  return [resize_image(image).convert("RGB") for image in images]
66
 
67
+ @lru_cache()
68
+ def get_valid_hub_image_classification_model_ids():
69
+ models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
70
+ return {model.id for model in models}
71
 
72
  def predict_subset(model_id, token):
73
  API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
74
  headers = {"Authorization": f"Bearer {token}"}
75
+ valid_model_ids = get_valid_hub_image_classification_model_ids()
76
+ if model_id not in valid_model_ids:
77
+ gr.Error(f"model_id {model_id} is not a valid image classification model id")
78
  @backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
79
  def _query(url):
80
  r = requests.post(API_URL, headers=headers, data=url)
 
81
  return r
82
 
83
  @lru_cache(maxsize=1000)