testwarm / hfsearch.py
John6666's picture
Upload 5 files
2303139 verified
raw
history blame
7.87 kB
import spaces
import gradio as gr
from huggingface_hub import HfApi
import gc
class Labels():
    """Registry of column labels, each with a data type and a sort order.

    ``types`` maps label -> data type name and ``orders`` maps label -> the
    integer used to order columns; both are keyed by the same labels.
    """

    # Data type names accepted by ``set``.
    VALID_DTYPE = ["str", "number", "bool", "date", "markdown"]

    def __init__(self):
        self.types = {}
        self.orders = {}

    def set(self, label: str, type: str="str", order: int=255):
        """Register (or overwrite) ``label`` with its data type and order.

        Raises:
            ValueError: if ``type`` is not one of ``VALID_DTYPE``. Nothing is
                stored in that case.
        """
        if type not in self.VALID_DTYPE:
            raise ValueError(f"Invalid data type: {type}")
        self.types[label] = type
        self.orders[label] = order

    def get(self):
        """Return ``(labels, label_types)`` sorted by ascending order value.

        Labels sharing the same order value keep their registration order
        because the sort is stable.
        """
        labels = sorted(self.types, key=self.orders.__getitem__)
        label_types = [self.types[label] for label in labels]
        return labels, label_types

    def get_null_value(self, type: str):
        """Return the placeholder used when a cell has no value.

        ``False`` for "bool", ``0`` for "number" and "date", and the string
        ``"None"`` for every other type name.
        """
        if type == "bool":
            return False
        if type in ("number", "date"):
            return 0
        return "None"
class HFSearchResult():
    """Accumulates search result rows and renders them for Gradio widgets.

    Rows are built one field at a time with ``set`` and committed with
    ``next``. Each row has a raw value dict (``item_list``) and an optional
    dict of display overrides (``show_item_list``). Columns can be hidden
    with ``set_hide`` and rows filtered with ``set_filter``.
    """

    # Column data types that ``_do_filter`` matches by substring.
    _TEXT_DTYPES = frozenset({"str", "markdown"})

    def __init__(self):
        self.labels = Labels()
        self.current_item = {}        # raw values of the row being built
        self.current_show_item = {}   # display overrides of the row being built
        self.item_list = []           # committed rows (raw values)
        self.show_item_list = []      # committed rows (display overrides)
        self.item_hide_flags = []     # per-row "hidden by filter" flags
        self.hide_item = []           # labels of columns to hide
        self.filter_items = None      # labels to filter on, or None
        self.filters = None           # substrings to match, or None
        gc.collect()

    def reset(self):
        """Discard all rows, labels, hidden columns and filters."""
        self.__init__()

    def set(self, data, label: str, type: str="str", order: int=255, show_data=None):
        """Set field ``label`` of the row under construction.

        ``show_data``, when not None, is what ``get_show`` displays for this
        cell in place of ``data``.
        """
        self.labels.set(label, type, order)
        self.current_item[label] = data
        if show_data is not None:
            self.current_show_item[label] = show_data

    def next(self):
        """Commit the row under construction and start an empty one."""
        self.item_list.append(self.current_item.copy())
        self.current_item = {}
        self.show_item_list.append(self.current_show_item.copy())
        self.current_show_item = {}

    def get(self):
        """Return ``(rows, labels, label_types)`` using raw values only.

        Cells missing from a row are filled with the type's null value.
        """
        labels, label_types = self.labels.get()
        null_of = self.labels.get_null_value
        rows = [
            [item.get(label, null_of(dtype)) for label, dtype in zip(labels, label_types)]
            for item in self.item_list
        ]
        return rows, labels, label_types

    def get_show(self):
        """Return ``(rows, labels, label_types)`` for display.

        Applies the current filter, drops hidden columns, and prefers a
        cell's display override over its raw value when one was set.
        """
        labels, label_types = self.labels.get()
        self._do_filter()
        # Build the hidden-column set once instead of once per cell.
        hidden = set(self.hide_item)
        columns = [(label, dtype) for label, dtype in zip(labels, label_types) if label not in hidden]
        null_of = self.labels.get_null_value
        rows = []
        for item, show_item, is_hide in zip(self.item_list, self.show_item_list, self.item_hide_flags):
            if is_hide:
                continue
            rows.append([
                show_item[label] if label in show_item else item.get(label, null_of(dtype))
                for label, dtype in columns
            ])
        show_labels = [label for label, _ in columns]
        show_label_types = [dtype for _, dtype in columns]
        return rows, show_labels, show_label_types

    def set_hide(self, hide_item: list):
        """Set the labels of the columns to hide (None means hide nothing)."""
        self.hide_item = list(hide_item) if hide_item else []

    def set_filter(self, filter_item1: str, filter1: str):
        """Filter rows to those whose ``filter_item1`` cell contains ``filter1``.

        The filter is cleared when either argument is empty or None: an
        empty needle matches every string, and a None needle would make the
        substring test in ``_do_filter`` raise ``TypeError``.
        """
        if not filter_item1 or not filter1:
            self.filter_items = None
            self.filters = None
        else:
            self.filter_items = [filter_item1]
            self.filters = [filter1]

    def _do_filter(self):
        """Recompute ``item_hide_flags`` from the current filter.

        A row is hidden when, for a filtered label it actually has, the
        value equals the type's null value or (for text types) does not
        contain the filter substring. Rows lacking the label are kept.
        """
        if self.filters is None or self.filter_items is None:
            self.item_hide_flags = [False] * len(self.item_list)
            return
        labels, label_types = self.labels.get()
        types = dict(zip(labels, label_types))
        flags = []
        for item in self.item_list:
            hide = False
            for key, needle in zip(self.filter_items, self.filters):
                if key not in item:
                    continue
                dtype = types[key]
                value = item[key]
                if value == self.labels.get_null_value(dtype):
                    hide = True
                    break
                # Non-text columns are not filtered.
                if dtype in self._TEXT_DTYPES and needle not in value:
                    hide = True
                    break
            flags.append(hide)
        self.item_hide_flags = flags

    def get_gr_df(self):
        """Return a ``gr.update`` carrying the display table and its headers."""
        df, labels, label_types = self.get_show()
        return gr.update(type="array", value=df, headers=labels, datatype=label_types)

    def get_gr_hide_item(self):
        """Return a ``gr.update`` offering every label as a hideable column."""
        return gr.update(choices=self.labels.get()[0], value=[], visible=True)

    def get_gr_filter_item(self, filter_item: str=""):
        """Return a ``gr.update`` listing the text columns that can be filtered.

        The selected value is ``filter_item`` when given, otherwise the first
        choice; ``[""]`` is offered when there is no text column.
        """
        labels, label_types = self.labels.get()
        choices = [label for label, dtype in zip(labels, label_types) if dtype in self._TEXT_DTYPES]
        if not choices:
            choices = [""]
        return gr.update(choices=choices, value=filter_item if filter_item else choices[0], visible=True)

    def get_gr_filter(self, filter_item: str=""):
        """Return a ``gr.update`` listing candidate filter values for a column.

        Offers an empty choice followed by at most 100 distinct values of
        ``filter_item`` across all rows, most frequent first (ties keep
        first-seen order). Previously the ascending sort kept the 100
        *least* frequent values.
        """
        labels = self.labels.get()[0]
        if not filter_item or filter_item not in set(labels):
            return gr.update(choices=[""], value="", visible=True)
        counts = {}
        for item in self.item_list:
            if filter_item not in item:
                continue
            value = item[filter_item]
            counts[value] = counts.get(value, 0) + 1
        ranked = sorted(counts, key=counts.get, reverse=True)[:100]
        return gr.update(choices=[""] + ranked, value="", visible=True)
def md_lb(s: str, count: int):
    """Break ``s`` into ``count``-character chunks joined by ``<br>``.

    A non-positive ``count`` returns ``s`` unchanged; slicing with a zero
    step would raise and a negative step would yield no chunks at all.
    """
    if count <= 0:
        return s
    return "<br>".join(s[start:start + count] for start in range(0, len(s), count))
# https://huggingface.co/docs/huggingface_hub/package_reference/hf_api
# https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.ModelInfo
@spaces.GPU
def search(sort: str, sort_method: str, filter: str, author: str, infer: str, gated: str, appr: list[str], limit: int, r: HFSearchResult):
    """Query the Hub for models and load the results into ``r``.

    Builds keyword arguments for ``HfApi.list_models`` from the non-empty
    inputs, resets ``r``, then adds one row per returned model. Models whose
    ``gated`` value is truthy and not listed in ``appr`` are skipped.

    Returns:
        ``(r.get_gr_df(), r.get_gr_hide_item(), r)``.

    Raises:
        gr.Error: wrapping any exception raised while searching, so the
            message reaches the UI; the original exception is chained.
    """
    try:
        api = HfApi()
        kwargs = {}
        if filter: kwargs["filter"] = filter
        if author: kwargs["author"] = author
        if gated == "gated": kwargs["gated"] = True
        elif gated == "non-gated": kwargs["gated"] = False
        if infer != "all": kwargs["inference"] = infer
        if sort_method == "descending order": kwargs["direction"] = -1
        # A missing (None) limit is treated like "no limit".
        if limit and limit > 0: kwargs["limit"] = limit
        models = api.list_models(sort=sort, cardData=True, full=True, **kwargs)
        r.reset()
        approved = appr or []
        i = 1  # running row number; only advanced for models that are kept
        for model in models:
            if model.gated and model.gated not in approved:
                continue
            r.set(i, "No.", "number", 0)
            r.set(model.id, "Model", "markdown", 2, f"[{md_lb(model.id, 48)}](https://hf.co/{model.id})")
            if model.inference is not None:
                r.set(model.inference, "Status", "markdown", 4, md_lb(model.inference, 8))
            #if infer != "all": r.set(infer, "Status", "markdown", 4)
            if model.gated is not None:
                r.set(model.gated if model.gated else "off", "Gated", "str", 6)
            #if gated != "all": r.set("on" if gated == "gated" else "off", "Gated", "str", 6)
            if model.library_name is not None:
                r.set(model.library_name, "Library", "markdown", 10, md_lb(model.library_name, 12))
            if model.pipeline_tag is not None:
                r.set(model.pipeline_tag, "Pipeline", "markdown", 11, md_lb(model.pipeline_tag, 15))
            if model.last_modified is not None:
                r.set(model.last_modified, "LastMod.", "date", 12)
            if model.likes is not None:
                r.set(model.likes, "Likes", "number", 13)
            if model.downloads is not None:
                r.set(model.downloads, "DLs", "number", 14)
            if model.downloads_all_time is not None:
                r.set(model.downloads_all_time, "AllDLs", "number", 15)
            r.next()
            i += 1
        return r.get_gr_df(), r.get_gr_hide_item(), r
    except Exception as e:
        # gr.Error takes a string message; keep the cause for the traceback.
        raise gr.Error(str(e)) from e
def update_df(hide_item: list, filter_item1: str, filter1: str, r: HFSearchResult):
    """Apply the hidden-column and filter settings to ``r`` and re-render.

    Returns:
        ``(r.get_gr_df(), r)`` computed after both settings are stored.
    """
    r.set_hide(hide_item)
    r.set_filter(filter_item1, filter1)
    table_update = r.get_gr_df()
    return table_update, r
def update_filter(filter_item1: str, r: HFSearchResult):
    """Build the widget updates for the filter controls.

    Returns:
        A 4-tuple: the filter-column update, the filter-value update for
        ``filter_item1``, a ``gr.update(visible=True)``, and ``r``.
    """
    column_update = r.get_gr_filter_item(filter_item1)
    value_update = r.get_gr_filter(filter_item1)
    show_update = gr.update(visible=True)
    return column_update, value_update, show_update, r