|
import spaces
|
|
import gradio as gr
|
|
from huggingface_hub import HfApi
|
|
import gc
|
|
|
|
class Labels:
    """Registry of result-table columns: each label's data type and display order."""

    # Column data types accepted by set(); HFSearchResult passes them to Gradio as `datatype`.
    VALID_DTYPE = ["str", "number", "bool", "date", "markdown"]

    def __init__(self):
        self.types = {}   # label -> data type (one of VALID_DTYPE)
        self.orders = {}  # label -> sort key; lower values are listed first

    def set(self, label: str, type: str = "str", order: int = 255):
        """Register column ``label`` (or update its type/order if already known).

        Args:
            label: Column header.
            type: Column data type; must be one of ``VALID_DTYPE``.
            order: Sort key for the column position (lower comes first).

        Raises:
            ValueError: if ``type`` is not in ``VALID_DTYPE``. ValueError is an
                Exception subclass, so existing ``except Exception`` handlers still apply.
        """
        if type not in self.VALID_DTYPE:
            raise ValueError(f"Invalid data type: {type}")
        self.types[label] = type
        self.orders[label] = order

    def get(self):
        """Return ``(labels, label_types)`` sorted by order.

        The sort is stable, so labels with equal order keep registration order.
        """
        labels = sorted(self.types, key=self.orders.__getitem__)
        return labels, [self.types[label] for label in labels]

    def get_null_value(self, type: str):
        """Return the placeholder used for a missing value of ``type``.

        ``False`` for "bool", ``0`` for "number" and "date", and the string
        ``"None"`` for every other type.
        """
        if type == "bool":
            return False
        if type in ("number", "date"):
            return 0
        return "None"
|
|
|
|
class HFSearchResult:
    """Accumulate Hub search results row by row and render them for Gradio.

    Each row is a dict keyed by column label holding the raw value. A parallel
    "show" dict may hold an alternative display value (e.g. a markdown link)
    for some columns; raw values are used for filtering and as the fallback
    display value.
    """

    # Column types that can be substring-filtered.
    _TEXT_TYPES = ("str", "markdown")

    def __init__(self):
        self.labels = Labels()
        # Row currently being built by set(); committed by next().
        self.current_item = {}
        self.current_show_item = {}
        # Committed rows: raw values and their optional display overrides (parallel lists).
        self.item_list = []
        self.show_item_list = []
        # One flag per committed row; True means the current filter hides the row.
        self.item_hide_flags = []
        # Column labels the user chose to hide.
        self.hide_item = []
        # Parallel lists of filtered column labels and substrings, or None for "no filter".
        self.filter_items = None
        self.filters = None
        # reset() re-runs __init__; presumably collected here to free the previous result set promptly.
        gc.collect()

    def reset(self):
        """Discard all rows, column definitions, hidden columns and filters."""
        self.__init__()

    def set(self, data, label: str, type: str = "str", order: int = 255, show_data=None):
        """Set column ``label`` of the row being built.

        Args:
            data: Raw value, used for filtering and as the default display value.
            label: Column header.
            type: Column data type; must be one of ``Labels.VALID_DTYPE``.
            order: Sort key for the column position (lower comes first).
            show_data: Optional display value shown instead of ``data`` by get_show().
        """
        self.labels.set(label, type, order)
        self.current_item[label] = data
        if show_data is not None:
            self.current_show_item[label] = show_data

    def next(self):
        """Commit the row being built and start a new, empty one."""
        self.item_list.append(self.current_item.copy())
        self.current_item = {}
        self.show_item_list.append(self.current_show_item.copy())
        self.current_show_item = {}

    def get(self):
        """Return ``(rows, labels, label_types)`` with raw values for every committed row.

        Cells missing from a row are filled with the column type's null value.
        """
        labels, label_types = self.labels.get()
        nulls = [self.labels.get_null_value(t) for t in label_types]
        df = [[item.get(label, null) for label, null in zip(labels, nulls)] for item in self.item_list]
        return df, labels, label_types

    def get_show(self):
        """Return ``(rows, labels, label_types)`` for display.

        Hidden columns and rows excluded by the current filter are dropped.
        Display overrides take precedence over raw values, and missing cells
        get the column type's null value.
        """
        labels, label_types = self.labels.get()
        self._do_filter()
        hidden = set(self.hide_item)
        # (label, type, null placeholder) for each visible column, computed once.
        columns = [(label, t, self.labels.get_null_value(t))
                   for label, t in zip(labels, label_types) if label not in hidden]
        df = [
            [show_item[label] if label in show_item else item.get(label, null)
             for label, _, null in columns]
            for item, show_item, is_hide in zip(self.item_list, self.show_item_list, self.item_hide_flags)
            if not is_hide
        ]
        show_labels = [label for label, _, _ in columns]
        show_label_types = [t for _, t, _ in columns]
        return df, show_labels, show_label_types

    def set_hide(self, hide_item: list):
        """Set the column labels to omit from get_show(); None is treated as "hide nothing"."""
        self.hide_item = hide_item if hide_item is not None else []

    def set_filter(self, filter_item1: str, filter1: str):
        """Filter rows by substring ``filter1`` in column ``filter_item1``.

        When both arguments are empty the filter is cleared. A missing (None)
        substring is treated as "" so the later ``in`` test cannot raise TypeError.
        """
        if not filter_item1 and not filter1:
            self.filter_items = None
            self.filters = None
        else:
            self.filter_items = [filter_item1]
            self.filters = [filter1 if filter1 is not None else ""]

    def _do_filter(self):
        """Recompute ``item_hide_flags`` (one per committed row) from the current filters.

        For each filter, a row is hidden if the filtered column holds the type's
        null value, or if the column is str/markdown and does not contain the
        filter substring. Rows that lack the filtered column are not hidden by it.
        """
        if self.filters is None or self.filter_items is None:
            self.item_hide_flags = [False] * len(self.item_list)
            return
        labels, label_types = self.labels.get()
        types = dict(zip(labels, label_types))
        conditions = list(zip(self.filter_items, self.filters))

        def is_hidden(item):
            for label, needle in conditions:
                if label not in item:
                    continue
                t = types[label]
                value = item[label]
                if value == self.labels.get_null_value(t):
                    return True
                if t in self._TEXT_TYPES and needle not in value:
                    return True
            return False

        self.item_hide_flags = [is_hidden(item) for item in self.item_list]

    def get_gr_df(self):
        """Return a Gradio update for the results Dataframe using get_show()."""
        df, labels, label_types = self.get_show()
        return gr.update(type="array", value=df, headers=labels, datatype=label_types)

    def get_gr_hide_item(self):
        """Return a Gradio update offering every column label as hideable, none selected."""
        return gr.update(choices=self.labels.get()[0], value=[], visible=True)

    def get_gr_filter_item(self, filter_item: str = ""):
        """Return a Gradio update listing the filterable (str/markdown) columns.

        ``filter_item`` is kept as the selected value only if it is still a
        valid choice (e.g. it may have vanished after a new search); otherwise
        the first choice is selected.
        """
        labels, label_types = self.labels.get()
        choices = [s for s, t in zip(labels, label_types) if t in self._TEXT_TYPES] or [""]
        value = filter_item if filter_item in choices else choices[0]
        return gr.update(choices=choices, value=value, visible=True)

    def get_gr_filter(self, filter_item: str = ""):
        """Return a Gradio update with candidate filter values for column ``filter_item``.

        Choices are "" (no filter) followed by at most 100 distinct raw values
        of that column, most frequent first.
        """
        from collections import Counter

        labels = self.labels.get()[0]
        if not filter_item or filter_item not in labels:
            return gr.update(choices=[""], value="", visible=True)
        counts = Counter(item[filter_item] for item in self.item_list if filter_item in item)
        # most_common keeps the frequent values; an ascending sort + [:100] kept the rarest.
        choices = [""] + [value for value, _ in counts.most_common(100)]
        return gr.update(choices=choices, value="", visible=True)
|
|
|
|
def md_lb(s: str, count: int):
    """Break ``s`` into consecutive ``count``-character pieces joined by ``<br>``.

    Used to wrap long cell text in markdown table columns.
    """
    starts = range(0, len(s), count)
    return "<br>".join(s[pos:pos + count] for pos in starts)
|
|
|
|
|
|
|
|
@spaces.GPU
def search(sort: str, sort_method: str, filter: str, author: str, infer: str, gated: str, appr: list[str], limit: int, r: HFSearchResult):
    """Query the Hugging Face Hub for models and load the results into ``r``.

    Args:
        sort: Passed as ``sort`` to ``HfApi.list_models``.
        sort_method: "descending order" requests ``direction=-1``; anything else
            leaves the API default.
        filter: Optional ``filter`` for list_models (ignored when empty).
        author: Optional ``author`` for list_models (ignored when empty).
        infer: Passed as ``inference`` unless it is "all".
        gated: "gated" / "non-gated" set ``gated=True`` / ``gated=False``;
            any other value applies no constraint.
        appr: ``model.gated`` values to keep; gated models whose value is not
            listed are skipped. None is treated as an empty list.
        limit: Maximum number of results; values <= 0 mean no limit.
        r: Result container; it is reset and refilled.

    Returns:
        Tuple of (Dataframe update, hide-item checkbox update, r).

    Raises:
        gr.Error: wrapping any exception raised while querying or building results.
    """
    try:
        api = HfApi()
        kwargs = {}
        if filter: kwargs["filter"] = filter
        if author: kwargs["author"] = author
        if gated == "gated": kwargs["gated"] = True
        elif gated == "non-gated": kwargs["gated"] = False
        if infer != "all": kwargs["inference"] = infer
        if sort_method == "descending order": kwargs["direction"] = -1
        if limit > 0: kwargs["limit"] = limit
        approved = set(appr or [])
        models = api.list_models(sort=sort, cardData=True, full=True, **kwargs)
        r.reset()
        row_no = 1
        for model in models:
            # Falsy gated (None / False) means the model is open; otherwise it must be approved.
            if model.gated and model.gated not in approved: continue
            r.set(row_no, "No.", "number", 0)
            r.set(model.id, "Model", "markdown", 2, f"[{md_lb(model.id, 48)}](https://hf.co/{model.id})")
            if model.inference is not None: r.set(model.inference, "Status", "markdown", 4, md_lb(model.inference, 8))
            if model.gated is not None: r.set(model.gated if model.gated else "off", "Gated", "str", 6)
            if model.library_name is not None: r.set(model.library_name, "Library", "markdown", 10, md_lb(model.library_name, 12))
            if model.pipeline_tag is not None: r.set(model.pipeline_tag, "Pipeline", "markdown", 11, md_lb(model.pipeline_tag, 15))
            if model.last_modified is not None: r.set(model.last_modified, "LastMod.", "date", 12)
            if model.likes is not None: r.set(model.likes, "Likes", "number", 13)
            if model.downloads is not None: r.set(model.downloads, "DLs", "number", 14)
            if model.downloads_all_time is not None: r.set(model.downloads_all_time, "AllDLs", "number", 15)
            r.next()
            row_no += 1
        return r.get_gr_df(), r.get_gr_hide_item(), r
    except Exception as e:
        # UI boundary: surface the failure in Gradio, keeping the original cause chained.
        raise gr.Error(str(e)) from e
|
|
|
|
def update_df(hide_item: list, filter_item1: str, filter1: str, r: HFSearchResult):
    """Apply the hidden-column and filter settings to ``r`` and re-render it.

    Returns:
        Tuple of (Dataframe update, r).
    """
    r.set_hide(hide_item)
    r.set_filter(filter_item1, filter1)
    table_update = r.get_gr_df()
    return table_update, r
|
|
|
|
def update_filter(filter_item1: str, r: HFSearchResult):
    """Refresh the filter-column dropdown and its candidate values for ``filter_item1``.

    Returns:
        Tuple of (filter-column update, filter-value update, visibility update, r).
    """
    column_update = r.get_gr_filter_item(filter_item1)
    values_update = r.get_gr_filter(filter_item1)
    visibility_update = gr.update(visible=True)
    return column_update, values_update, visibility_update, r
|
|
|