|
import json |
|
import gradio as gr |
|
import pandas as pd |
|
import plotly.express as px |
|
|
|
# Choices offered in the "Select Pipeline Tag" dropdown, in display order.
# A pipeline filter keeps repos whose "pipeline_tag" equals the chosen value.
PIPELINE_TAGS = [
    "text-generation", "text-to-image", "text-classification",
    "text2text-generation", "audio-to-audio", "feature-extraction",
    "image-classification", "translation", "reinforcement-learning",
    "fill-mask", "text-to-speech", "automatic-speech-recognition",
    "image-text-to-text", "token-classification", "sentence-similarity",
    "question-answering", "image-feature-extraction", "summarization",
    "zero-shot-image-classification", "object-detection", "image-segmentation",
    "image-to-image", "image-to-text", "audio-classification",
    "visual-question-answering", "text-to-video", "zero-shot-classification",
    "depth-estimation", "text-ranking", "image-to-video",
    "multiple-choice", "unconditional-image-generation", "video-classification",
    "text-to-audio", "time-series-forecasting", "any-to-any",
    "video-text-to-text", "table-question-answering",
]
|
|
|
def is_audio_speech(repo_dct):
    """Return True if the repo looks audio- or speech-related.

    A repo matches when its "pipeline_tag" or any entry of its "tags"
    contains "audio" or "speech" (case-insensitive). A missing or falsy
    "pipeline_tag"/"tags" value counts as no match for that key.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool: always a real bool (never None or an empty container).
    """
    keywords = ("audio", "speech")
    pipeline_tag = (repo_dct.get("pipeline_tag") or "").lower()
    if any(keyword in pipeline_tag for keyword in keywords):
        return True
    tags = repo_dct.get("tags") or []
    # Lowercase each tag once, then test it against both keywords.
    return any(
        keyword in lowered
        for lowered in (tag.lower() for tag in tags)
        for keyword in keywords
    )
|
|
|
def is_music(repo_dct):
    """Return True if any entry of the repo's "tags" contains "music".

    The match is case-insensitive. A missing, None or empty "tags" value
    yields False.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool
    """
    tags = repo_dct.get("tags") or []
    return any("music" in tag.lower() for tag in tags)
|
|
|
def is_robotics(repo_dct):
    """Return True if any entry of the repo's "tags" contains "robot".

    The match is case-insensitive, so "robotics" and "LeRobot" both match.
    A missing, None or empty "tags" value yields False.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool
    """
    tags = repo_dct.get("tags") or []
    return any("robot" in tag.lower() for tag in tags)
|
|
|
def is_biomed(repo_dct):
    """Return True if any entry of the repo's "tags" contains "bio" or "medic".

    The match is case-insensitive. A missing, None or empty "tags" value
    yields False.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool
    """
    tags = repo_dct.get("tags") or []
    # Lowercase each tag once, then test both substrings.
    return any(
        "bio" in lowered or "medic" in lowered
        for lowered in (tag.lower() for tag in tags)
    )
|
|
|
def is_timeseries(repo_dct):
    """Return True if any entry of the repo's "tags" contains "series".

    The match is case-insensitive, so e.g. "Time-Series" matches. A missing,
    None or empty "tags" value yields False.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool
    """
    tags = repo_dct.get("tags") or []
    return any("series" in tag.lower() for tag in tags)
|
|
|
def is_science(repo_dct):
    """Return True if any entry of the repo's "tags" mentions "science".

    Tags that contain "bigscience" (the name of an organization, not a topic)
    are ignored. Both checks are case-insensitive. A missing, None or empty
    "tags" value yields False.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool
    """
    tags = repo_dct.get("tags") or []
    # Lowercase once so the "bigscience" exclusion uses the same casing as
    # the "science" match (previously "BigScience" slipped through).
    return any(
        "science" in lowered and "bigscience" not in lowered
        for lowered in (tag.lower() for tag in tags)
    )
|
|
|
def is_video(repo_dct):
    """Return True if any entry of the repo's "tags" contains "video".

    The match is case-insensitive. A missing, None or empty "tags" value
    yields False.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool
    """
    tags = repo_dct.get("tags") or []
    return any("video" in tag.lower() for tag in tags)
|
|
|
def is_image(repo_dct):
    """Return True if any entry of the repo's "tags" contains "image".

    The match is case-insensitive. A missing, None or empty "tags" value
    yields False.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool
    """
    tags = repo_dct.get("tags") or []
    return any("image" in tag.lower() for tag in tags)
|
|
|
def is_text(repo_dct):
    """Return True if any entry of the repo's "tags" contains "text".

    The match is case-insensitive. A missing, None or empty "tags" value
    yields False.

    Args:
        repo_dct: Repository metadata dict.

    Returns:
        bool
    """
    tags = repo_dct.get("tags") or []
    return any("text" in tag.lower() for tag in tags)
|
|
|
# Topic label shown in the "Select Tag" dropdown -> predicate deciding whether
# a repo dict belongs to that topic. Insertion order is the dropdown order.
TAG_FILTER_FUNCS = dict(
    [
        ("Audio & Speech", is_audio_speech),
        ("Time series", is_timeseries),
        ("Robotics", is_robotics),
        ("Music", is_music),
        ("Video", is_video),
        ("Images", is_image),
        ("Text", is_text),
        ("Biomedical", is_biomed),
        ("Sciences", is_science),
    ]
)
|
|
|
def make_org_stats(repo_type, count_by, org_stats, top_k=20, filter_func=None):
    """Rank authors by a metric and compute percentage shares for plotting.

    Args:
        repo_type: "all", "datasets" or "models"; "all" combines both kinds.
        count_by: Metric key read from each repo dict ("likes", "downloads"
            or "downloads_all").
        org_stats: Mapping of author name -> dict whose "datasets" / "models"
            entries are lists of repo dicts, each carrying an "id" and the
            ``count_by`` key. A missing repo-kind entry is treated as empty.
        top_k: Number of top authors kept individually; all remaining authors
            are pooled into one "Others..." entry.
        filter_func: Optional predicate on a repo dict; repos for which it is
            falsy are ignored. ``None`` keeps every repo.

    Returns:
        A ``(summary, plot_rows)`` tuple:

        * ``summary``: ``(author, percent)`` for each top author and for
          "Others...", keeping only entries with a positive total.
        * ``plot_rows``: ``(author, repo_id, percent)`` for every matching
          repo of each top author, followed by
          ``("Others...", "other", percent)``.

        Percentages are relative to the metric summed over all authors and
        are 0.0 when that sum is zero.
    """
    assert count_by in ["likes", "downloads", "downloads_all"]
    assert repo_type in ["all", "datasets", "models"]
    repos = ["datasets", "models"] if repo_type == "all" else [repo_type]
    if filter_func is None:
        def filter_func(_dct):
            return True

    def _selected(author_dct):
        """Yield the author's repo dicts of the requested kinds that pass the filter."""
        for repo in repos:
            for dct in author_dct.get(repo, []):
                if filter_func(dct):
                    yield dct

    # One total per author, summed over every requested repo kind, so that
    # repo_type="all" cannot list (and later double-count) an author twice.
    sorted_stats = sorted(
        (
            (author, sum(dct[count_by] for dct in _selected(author_dct)))
            for author, author_dct in org_stats.items()
        ),
        key=lambda x: x[1],
        reverse=True,
    )
    top = sorted_stats[:top_k]
    others_total = sum(st for _, st in sorted_stats[top_k:])
    res = top + [("Others...", others_total)]
    total_st = sum(st for _, st in res)

    def _pct(value):
        """Return ``value`` as a percentage of the grand total (0.0 if the total is 0)."""
        return value * 100 / total_st if total_st else 0.0

    # The "Others..." row is appended separately instead of being detected by
    # name, so a real author called "Others..." cannot be mistaken for it.
    res_plot_df = [
        (org, dct["id"], _pct(dct[count_by]))
        for org, _ in top
        for dct in _selected(org_stats[org])
    ]
    res_plot_df.append(("Others...", "other", _pct(others_total)))
    return ([(o, _pct(st)) for o, st in res if st > 0], res_plot_df)
|
|
|
def make_figure(count_by, repo_type, org_stats, tag_filter=None, pipeline_filter=None):
    """Build a Plotly treemap of how the top 25 authors contribute to a metric.

    Args:
        count_by: "downloads", "likes" or "downloads_all".
        repo_type: "all", "models" or "datasets"; also used as the name and
            value of the treemap's root level.
        org_stats: Per-author statistics passed through to ``make_org_stats``.
        tag_filter: Optional key of ``TAG_FILTER_FUNCS`` restricting repos
            to one topic.
        pipeline_filter: Optional pipeline tag; only repos whose
            "pipeline_tag" equals it are kept. At most one of the two
            filters may be given.

    Returns:
        The treemap figure produced by ``plotly.express.treemap``.
    """
    assert count_by in ["downloads", "likes", "downloads_all"]
    assert repo_type in ["all", "models", "datasets"]
    assert tag_filter is None or pipeline_filter is None

    filter_func = TAG_FILTER_FUNCS[tag_filter] if tag_filter else None
    if pipeline_filter:
        # Overrides any tag filter; pipeline_filter is truthy here, so a
        # missing/None "pipeline_tag" can never compare equal to it.
        def filter_func(dct):
            return dct.get("pipeline_tag") == pipeline_filter

    _, plot_rows = make_org_stats(
        repo_type, count_by, org_stats, top_k=25, filter_func=filter_func
    )

    # Each plot row is an (organization, repo id, percentage) triple.
    df = pd.DataFrame(plot_rows, columns=["organizations", "repo", "stats"])
    # Constant column that acts as the single root node of the treemap.
    df[repo_type] = repo_type

    fig = px.treemap(df, path=[repo_type, "organizations", "repo"], values="stats")
    fig.update_layout(
        treemapcolorway=["pink"] * len(plot_rows),
        margin=dict(t=50, l=25, r=25, b=25),
    )
    return fig
|
|
|
|
|
with gr.Blocks() as demo:
    # Per-session holder for the organization statistics; it starts as None
    # and is filled by load_org_data when the page loads.
    org_stats_data = gr.State(value=None)

    with gr.Row():
        gr.Markdown("""
        ## Hugging Face Organization Stats

        This app shows how different organizations are contributing to different aspects of the open AI ecosystem.
        Use the dropdowns on the left to select repository types, metrics, and optionally tags representing topics or modalities of interest.
        """)

    with gr.Row():
        with gr.Column(scale=1):
            repo_type_dropdown = gr.Dropdown(
                label="Repository Type",
                choices=["all", "models", "datasets"],
                value="all",
            )
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=["downloads", "likes", "downloads_all"],
                value="downloads",
            )

            filter_choice_radio = gr.Radio(
                label="Filter by",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None",
            )

            # Only the dropdown matching the radio selection is shown (see
            # update_filter_visibility below).
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False,
            )
            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False,
            )

            generate_plot_button = gr.Button("Generate Plot")

        with gr.Column(scale=3):
            plot_output = gr.Plot()

    def generate_plot_on_click(repo_type, count_by, filter_choice, tag_filter, pipeline_filter, data):
        """Click handler: log the inputs and build the treemap for them.

        Only the filter matching ``filter_choice`` is forwarded; the other
        dropdown's (possibly stale) value is ignored.

        Returns:
            The figure from ``make_figure``, or None when the organization
            data has not been loaded into the session state yet.
        """
        print("Generating plot with the following inputs:")
        print(f" Repository Type: {repo_type}")
        print(f" Metric (Count By): {count_by}")
        print(f" Filter Choice: {filter_choice}")
        if filter_choice == "Tag Filter":
            print(f" Tag Filter: {tag_filter}")
        elif filter_choice == "Pipeline Filter":
            print(f" Pipeline Filter: {pipeline_filter}")

        if data is None:
            print("Error: Data not loaded yet.")
            return None

        selected_tag_filter = tag_filter if filter_choice == "Tag Filter" else None
        selected_pipeline_filter = (
            pipeline_filter if filter_choice == "Pipeline Filter" else None
        )

        return make_figure(
            count_by=count_by,
            repo_type=repo_type,
            org_stats=data,
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter,
        )

    def update_filter_visibility(filter_choice):
        """Return visibility updates for (tag dropdown, pipeline dropdown)."""
        return (
            gr.update(visible=filter_choice == "Tag Filter"),
            gr.update(visible=filter_choice == "Pipeline Filter"),
        )

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown],
    )

    def load_org_data():
        """Read the organization statistics JSON file from the working directory."""
        print("Loading organization statistics data...")
        # Context manager closes the handle; JSON is UTF-8 by specification,
        # so don't depend on the platform's default locale encoding.
        with open("org_to_artifacts_2l_stats.json", encoding="utf-8") as stats_file:
            loaded_org_stats = json.load(stats_file)
        print("Data loaded successfully.")
        return loaded_org_stats

    demo.load(
        fn=load_org_data,
        inputs=[],
        outputs=[org_stats_data],
    )

    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            repo_type_dropdown,
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            org_stats_data,
        ],
        outputs=[plot_output],
    )


if __name__ == "__main__":
    demo.launch()