Spaces:
Runtime error
Runtime error
Commit
·
de33a84
1
Parent(s):
d5321e7
check valid hub id
Browse files
app.py
CHANGED
@@ -10,8 +10,7 @@ import pandas as pd
|
|
10 |
import os
|
11 |
import backoff
|
12 |
from functools import lru_cache
|
13 |
-
|
14 |
-
import os
|
15 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
16 |
|
17 |
|
@@ -65,15 +64,20 @@ def return_random_sample(k=27):
|
|
65 |
images = dataset[sample]["image"]
|
66 |
return [resize_image(image).convert("RGB") for image in images]
|
67 |
|
|
|
|
|
|
|
|
|
68 |
|
69 |
def predict_subset(model_id, token):
|
70 |
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
|
71 |
headers = {"Authorization": f"Bearer {token}"}
|
72 |
-
|
|
|
|
|
73 |
@backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
|
74 |
def _query(url):
|
75 |
r = requests.post(API_URL, headers=headers, data=url)
|
76 |
-
print(r)
|
77 |
return r
|
78 |
|
79 |
@lru_cache(maxsize=1000)
|
|
|
10 |
import os
|
11 |
import backoff
|
12 |
from functools import lru_cache
|
13 |
+
from huggingface_hub import list_models, ModelFilter
|
|
|
14 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
15 |
|
16 |
|
|
|
64 |
images = dataset[sample]["image"]
|
65 |
return [resize_image(image).convert("RGB") for image in images]
|
66 |
|
67 |
+
@lru_cache()
|
68 |
+
def get_valid_hub_image_classification_model_ids():
|
69 |
+
models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
|
70 |
+
return {model.id for model in models}
|
71 |
|
72 |
def predict_subset(model_id, token):
|
73 |
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
|
74 |
headers = {"Authorization": f"Bearer {token}"}
|
75 |
+
valid_model_ids = get_valid_hub_image_classification_model_ids()
|
76 |
+
if model_id not in valid_model_ids:
|
77 |
+
gr.Error(f"model_id {model_id} is not a valid image classification model id")
|
78 |
@backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
|
79 |
def _query(url):
|
80 |
r = requests.post(API_URL, headers=headers, data=url)
|
|
|
81 |
return r
|
82 |
|
83 |
@lru_cache(maxsize=1000)
|