"""Gradio app visualizing how Hugging Face organizations contribute to the ecosystem.

Statistics are read from a precomputed JSON file mapping each organization to
its "datasets" and "models" repositories; each repository entry is a dict with
(at least, as read by this module) an "id", the selected metric key
("likes" / "downloads" / "downloads_all"), and optionally "tags" and
"pipeline_tag".
"""
import functools
import json

import gradio as gr
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Precomputed organization -> artifacts statistics consumed by the app.
ORG_STATS_PATH = "org_to_artifacts_2l_stats.json"

PIPELINE_TAGS = [
    'text-generation',
    'text-to-image',
    'text-classification',
    'text2text-generation',
    'audio-to-audio',
    'feature-extraction',
    'image-classification',
    'translation',
    'reinforcement-learning',
    'fill-mask',
    'text-to-speech',
    'automatic-speech-recognition',
    'image-text-to-text',
    'token-classification',
    'sentence-similarity',
    'question-answering',
    'image-feature-extraction',
    'summarization',
    'zero-shot-image-classification',
    'object-detection',
    'image-segmentation',
    'image-to-image',
    'image-to-text',
    'audio-classification',
    'visual-question-answering',
    'text-to-video',
    'zero-shot-classification',
    'depth-estimation',
    'text-ranking',
    'image-to-video',
    'multiple-choice',
    'unconditional-image-generation',
    'video-classification',
    'text-to-audio',
    'time-series-forecasting',
    'any-to-any',
    'video-text-to-text',
    'table-question-answering',
]


def _pipeline_tag_contains(repo_dct, *needles):
    """Return True if the repo's ``pipeline_tag`` contains any of *needles*.

    Matching is a case-insensitive substring test; a missing or empty
    ``pipeline_tag`` never matches.
    """
    pipeline_tag = repo_dct.get("pipeline_tag")
    if not pipeline_tag:
        return False
    pipeline_tag = pipeline_tag.lower()
    return any(needle in pipeline_tag for needle in needles)


def _tags_contain(repo_dct, *needles, exclude=()):
    """Return True if any of the repo's ``tags`` contains any of *needles*.

    Matching is a case-insensitive substring test. Tags containing any string
    in *exclude* (also compared case-insensitively) are ignored. A missing or
    empty ``tags`` list never matches.
    """
    tags = repo_dct.get("tags")
    if not tags:
        return False
    for tag in tags:
        lowered = tag.lower()
        if any(excluded in lowered for excluded in exclude):
            continue
        if any(needle in lowered for needle in needles):
            return True
    return False


def is_audio_speech(repo_dct):
    """Return True for repos whose pipeline tag or tags mention audio or speech."""
    return (_pipeline_tag_contains(repo_dct, "audio", "speech")
            or _tags_contain(repo_dct, "audio", "speech"))


def is_music(repo_dct):
    """Return True for repos with a tag mentioning music."""
    return _tags_contain(repo_dct, "music")


def is_robotics(repo_dct):
    """Return True for repos with a tag mentioning robots/robotics."""
    return _tags_contain(repo_dct, "robot")


def is_biomed(repo_dct):
    """Return True for repos with a tag mentioning "bio" or "medic"."""
    return _tags_contain(repo_dct, "bio", "medic")


def is_timeseries(repo_dct):
    """Return True for repos with a tag mentioning "series"."""
    return _tags_contain(repo_dct, "series")


def is_science(repo_dct):
    """Return True for repos with a "science" tag, ignoring "bigscience" tags.

    The "bigscience" exclusion is case-insensitive, like the match itself, so
    tags such as "BigScience" are not counted as science.
    """
    return _tags_contain(repo_dct, "science", exclude=("bigscience",))


def is_video(repo_dct):
    """Return True for repos with a tag mentioning video."""
    return _tags_contain(repo_dct, "video")


def is_image(repo_dct):
    """Return True for repos with a tag mentioning image."""
    return _tags_contain(repo_dct, "image")


def is_text(repo_dct):
    """Return True for repos with a tag mentioning text."""
    return _tags_contain(repo_dct, "text")


TAG_FILTER_FUNCS = {
    "Audio & Speech": is_audio_speech,
    "Time series": is_timeseries,
    "Robotics": is_robotics,
    "Music": is_music,
    "Video": is_video,
    "Images": is_image,
    "Text": is_text,
    "Biomedical": is_biomed,
    "Sciences": is_science,
}


def make_org_stats(repo_type, count_by, org_stats, top_k=20, filter_func=None):
    """Rank organizations by a metric and compute percentage shares.

    Args:
        repo_type: "all", "datasets" or "models"; "all" sums both repo kinds.
        count_by: metric key read from each repo dict
            ("likes", "downloads" or "downloads_all").
        org_stats: mapping of organization -> {"datasets": [...], "models": [...]}
            where each list holds repo dicts with an "id" and the metric key.
        top_k: number of organizations listed individually; the remainder is
            aggregated into an "Others..." entry.
        filter_func: optional predicate on a repo dict; repos for which it is
            falsy are ignored. ``None`` keeps every repo.

    Returns:
        A tuple ``(org_shares, plot_rows)``:
        - ``org_shares``: list of ``(org, percent)`` for the top-k orgs plus
          "Others...", omitting entries with a zero total.
        - ``plot_rows``: list of ``(org, repo_id, percent)`` rows for the
          treemap, one per positive-valued matching repo of each top-k org,
          plus ``("Others...", "other", percent)`` when the remainder is
          positive.
        Both lists are empty when no matching repo has a positive metric.
    """
    assert count_by in ["likes", "downloads", "downloads_all"]
    assert repo_type in ["all", "datasets", "models"]
    repos = ["datasets", "models"] if repo_type == "all" else [repo_type]
    if filter_func is None:
        def filter_func(_dct):
            return True

    def matching_repos(author_dct):
        """Yield the author's repo dicts of the selected kinds that pass the filter."""
        for repo in repos:
            # Tolerate organizations that lack one of the repo kinds.
            for dct in author_dct.get(repo, []):
                if filter_func(dct):
                    yield dct

    # One entry per organization, summed across ALL selected repo kinds. Summing
    # per (repo kind, org) instead would let an org appear twice in the top-k
    # and be double counted in the plot rows.
    sorted_stats = sorted(
        (
            (author, sum(dct[count_by] for dct in matching_repos(author_dct)))
            for author, author_dct in org_stats.items()
        ),
        key=lambda item: item[1],
        reverse=True,
    )
    top_stats = sorted_stats[:top_k]
    others_total = sum(st for _, st in sorted_stats[top_k:])
    res = top_stats + [("Others...", others_total)]
    total_st = sum(st for _, st in res)
    if total_st == 0:
        # Nothing matched (or everything is zero): avoid dividing by zero.
        return [], []

    res_plot_df = []
    for org, _ in top_stats:
        for dct in matching_repos(org_stats[org]):
            value = dct[count_by]
            if value > 0:  # zero-sized treemap leaves are invisible anyway
                res_plot_df.append((org, dct["id"], value * 100 / total_st))
    if others_total > 0:
        res_plot_df.append(("Others...", "other", others_total * 100 / total_st))

    org_shares = [(org, 100 * st / total_st) for org, st in res if st > 0]
    return org_shares, res_plot_df


def make_figure(count_by, repo_type, org_stats, tag_filter=None, pipeline_filter=None):
    """Build a treemap of per-organization / per-repo shares of *count_by*.

    Args:
        count_by: "downloads", "likes" or "downloads_all".
        repo_type: "all", "models" or "datasets".
        org_stats: organization statistics mapping (see ``make_org_stats``).
        tag_filter: optional key of ``TAG_FILTER_FUNCS`` restricting repos.
        pipeline_filter: optional exact ``pipeline_tag`` value restricting repos.
            At most one of *tag_filter* / *pipeline_filter* may be given.

    Returns:
        A plotly figure. When no repository matches, an empty figure with an
        explanatory annotation is returned instead of a treemap.
    """
    assert count_by in ["downloads", "likes", "downloads_all"]
    assert repo_type in ["all", "models", "datasets"]
    assert tag_filter is None or pipeline_filter is None
    filter_func = None
    if tag_filter:
        filter_func = TAG_FILTER_FUNCS[tag_filter]
    if pipeline_filter:
        def filter_func(dct):
            return dct.get("pipeline_tag") == pipeline_filter

    _, res_plot_df = make_org_stats(repo_type, count_by, org_stats, top_k=25, filter_func=filter_func)

    if not res_plot_df:
        fig = go.Figure()
        fig.add_annotation(
            text="No repositories match the selected filters.",
            showarrow=False,
            xref="paper",
            yref="paper",
            x=0.5,
            y=0.5,
        )
        fig.update_xaxes(visible=False)
        fig.update_yaxes(visible=False)
        return fig

    df = pd.DataFrame(res_plot_df, columns=["organizations", "repo", "stats"])
    df[repo_type] = repo_type  # in order to have a single root node
    fig = px.treemap(df, path=[repo_type, "organizations", "repo"], values="stats")
    fig.update_layout(
        treemapcolorway=["pink"] * len(res_plot_df),
        margin=dict(t=50, l=25, r=25, b=25),
    )
    return fig


APP_DESCRIPTION = """
## Hugging Face Organization Stats

This app shows how different organizations are contributing to different aspects of the open AI ecosystem.
Use the dropdowns on the left to select repository types, metrics, and optionally tags representing topics or modalities of interest.
"""


@functools.lru_cache(maxsize=1)
def _read_org_stats(path=ORG_STATS_PATH):
    """Parse the statistics JSON once and reuse it for every later call."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def load_org_data():
    """Return the organization statistics for a new session.

    ``demo.load`` runs this on every page load; the underlying file is only
    read and parsed on the first call thanks to ``_read_org_stats``'s cache.
    """
    print("Loading organization statistics data...")
    loaded_org_stats = _read_org_stats()
    print("Data loaded successfully.")
    return loaded_org_stats


def generate_plot_on_click(repo_type, count_by, filter_choice, tag_filter, pipeline_filter, data):
    """Button handler: build the treemap for the current UI selections.

    Returns ``None`` (clearing the plot) when the statistics are not loaded yet.
    Only the filter matching *filter_choice* is applied; the other is ignored.
    """
    # Print the current state of the input variables
    print("Generating plot with the following inputs:")
    print(f"  Repository Type: {repo_type}")
    print(f"  Metric (Count By): {count_by}")
    print(f"  Filter Choice: {filter_choice}")
    if filter_choice == "Tag Filter":
        print(f"  Tag Filter: {tag_filter}")
    elif filter_choice == "Pipeline Filter":
        print(f"  Pipeline Filter: {pipeline_filter}")

    if data is None:
        print("Error: Data not loaded yet.")
        return None

    selected_tag_filter = None
    selected_pipeline_filter = None
    if filter_choice == "Tag Filter":
        selected_tag_filter = tag_filter
    elif filter_choice == "Pipeline Filter":
        selected_pipeline_filter = pipeline_filter

    return make_figure(
        count_by=count_by,
        repo_type=repo_type,
        org_stats=data,
        tag_filter=selected_tag_filter,
        pipeline_filter=selected_pipeline_filter,
    )


def update_filter_visibility(filter_choice):
    """Show only the dropdown that matches the selected filter mode."""
    if filter_choice == "Tag Filter":
        return gr.update(visible=True), gr.update(visible=False)
    if filter_choice == "Pipeline Filter":
        return gr.update(visible=False), gr.update(visible=True)
    # "None"
    return gr.update(visible=False), gr.update(visible=False)


with gr.Blocks() as demo:
    org_stats_data = gr.State(value=None)  # filled by demo.load below

    with gr.Row():
        gr.Markdown(APP_DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            repo_type_dropdown = gr.Dropdown(
                label="Repository Type",
                choices=["all", "models", "datasets"],
                value="all",
            )
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=["downloads", "likes", "downloads_all"],
                value="downloads",
            )
            filter_choice_radio = gr.Radio(
                label="Filter by",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None",
            )
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False,
            )
            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False,
            )
            generate_plot_button = gr.Button("Generate Plot")

        with gr.Column(scale=3):
            plot_output = gr.Plot()

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown],
    )

    # Populate the session state on every page load (file parse is cached).
    demo.load(
        fn=load_org_data,
        inputs=[],  # No inputs needed to just load data
        outputs=[org_stats_data],  # Only output to the state
    )

    # Button click event to generate plot
    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            repo_type_dropdown,
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            org_stats_data,
        ],
        outputs=[plot_output],
    )


if __name__ == "__main__":
    demo.launch()