|
import json |
|
import gradio as gr |
|
import pandas as pd |
|
import plotly.express as px |
|
import os |
|
import numpy as np |
|
import io |
|
import duckdb |
|
|
|
|
|
# Pipeline tags offered by the "Select Pipeline Tag" dropdown, in display order.
PIPELINE_TAGS = [
    'text-generation', 'text-to-image', 'text-classification',
    'text2text-generation', 'audio-to-audio', 'feature-extraction',
    'image-classification', 'translation', 'reinforcement-learning',
    'fill-mask', 'text-to-speech', 'automatic-speech-recognition',
    'image-text-to-text', 'token-classification', 'sentence-similarity',
    'question-answering', 'image-feature-extraction', 'summarization',
    'zero-shot-image-classification', 'object-detection', 'image-segmentation',
    'image-to-image', 'image-to-text', 'audio-classification',
    'visual-question-answering', 'text-to-video', 'zero-shot-classification',
    'depth-estimation', 'text-ranking', 'image-to-video',
    'multiple-choice', 'unconditional-image-generation', 'video-classification',
    'text-to-audio', 'time-series-forecasting', 'any-to-any',
    'video-text-to-text', 'table-question-answering',
]
|
|
|
|
|
# Size-bucket labels for the "Model Size Filter" dropdown, each paired with a
# nominal (lower, upper) bound matching its label; the last bucket is
# open-ended. Elsewhere in this file only the keys (labels) are read.
MODEL_SIZE_RANGES = dict([
    ("Small (<1GB)", (0, 1)),
    ("Medium (1-5GB)", (1, 5)),
    ("Large (5-20GB)", (5, 20)),
    ("X-Large (20-50GB)", (20, 50)),
    ("XX-Large (>50GB)", (50, float('inf'))),
])
|
|
|
|
|
def _cached_flag(row, column):
    """Return the pre-computed flag stored under *column* in a model row.

    Raises whatever the row type raises for a missing key (``KeyError`` for
    dicts and pandas Series).
    """
    return row[column]


def is_audio_speech(row):
    """Return the row's cached ``is_audio_speech`` flag."""
    return _cached_flag(row, 'is_audio_speech')


def is_music(row):
    """Return the row's cached ``has_music`` flag."""
    return _cached_flag(row, 'has_music')


def is_robotics(row):
    """Return the row's cached ``has_robot`` flag."""
    return _cached_flag(row, 'has_robot')


def is_biomed(row):
    """Return the row's cached ``is_biomed`` flag."""
    return _cached_flag(row, 'is_biomed')


def is_timeseries(row):
    """Return the row's cached ``has_series`` flag."""
    return _cached_flag(row, 'has_series')


def is_science(row):
    """Return the row's cached ``has_science`` flag."""
    return _cached_flag(row, 'has_science')


def is_video(row):
    """Return the row's cached ``has_video`` flag."""
    return _cached_flag(row, 'has_video')
|
|
|
def _tags_contain(row, needle):
    """Return True if any tag of *row* contains *needle* (case-insensitive).

    ``row`` may be a dict or a pandas Series. Its ``tags`` entry may be missing,
    None, a list/tuple, a numpy array or numpy scalar (anything exposing
    ``dtype`` and ``tolist``), or a single string. For a string, the substring
    test is applied to the whole string. Any other type yields False.
    """
    tags = row.get("tags", [])
    if tags is None:
        return False
    if hasattr(tags, 'dtype') and hasattr(tags, 'tolist'):
        # numpy containers/scalars -> plain Python list or scalar.
        tags = tags.tolist()
    if isinstance(tags, str):
        return needle in tags.lower()
    if isinstance(tags, (list, tuple)):
        return any(needle in str(tag).lower() for tag in tags)
    return False


def is_image(row):
    """Return True if any of the row's tags mentions "image".

    Unlike the cached-column accessors (e.g. ``is_video``), this scans the raw
    ``tags`` entry, so it also works on rows without pre-computed columns.
    """
    return _tags_contain(row, "image")


def is_text(row):
    """Return True if any of the row's tags mentions "text".

    Like ``is_image``, this scans the raw ``tags`` entry rather than a cached
    column.
    """
    return _tags_contain(row, "text")
|
|
|
def extract_model_size(safetensors_data):
    """Return a model's size in "GB" derived from its ``safetensors`` metadata.

    ``safetensors_data`` may be a dict, a JSON-encoded string of a dict, or a
    missing value (None/NaN). The dict's ``total`` entry is divided by
    1024**3. Anything that cannot be interpreted (missing/NaN/non-numeric
    ``total``, malformed JSON, non-dict payload) yields 0.

    NOTE(review): on the Hugging Face Hub the safetensors ``total`` field is
    documented as a parameter *count*, not a byte count, so this figure is
    closer to "parameters / 2**30" than gigabytes on disk. Confirm intent.
    """
    if isinstance(safetensors_data, str):
        try:
            safetensors_data = json.loads(safetensors_data)
        except (ValueError, RecursionError):  # json.JSONDecodeError is a ValueError
            return 0

    if not isinstance(safetensors_data, dict):
        # None, NaN, and any other unsupported payload.
        return 0

    try:
        size_gb = float(safetensors_data.get('total')) / (1024 * 1024 * 1024)
    except (TypeError, ValueError, OverflowError):
        # Missing (None), non-numeric, or too-large-for-float totals.
        return 0

    # Treat a NaN total as "unknown", like every other unparseable input.
    return 0 if pd.isna(size_gb) else size_gb
|
|
|
|
|
def is_in_size_range(row, size_range):
    """Return True if *row* falls in the size bucket named *size_range*.

    ``None`` or the literal string ``"None"`` means "no size filter" and always
    matches, without reading the row. Otherwise the row's pre-computed
    ``size_category`` is compared with *size_range*.
    """
    if size_range is not None and size_range != "None":
        return row['size_category'] == size_range
    return True
|
|
|
# Tag-filter label -> row predicate. The UI offers these keys in the
# "Select Tag" dropdown; make_treemap_data filters on pre-computed boolean
# columns rather than calling these predicates row by row.
TAG_FILTER_FUNCS = dict([
    ("Audio & Speech", is_audio_speech),
    ("Time series", is_timeseries),
    ("Robotics", is_robotics),
    ("Music", is_music),
    ("Video", is_video),
    ("Images", is_image),
    ("Text", is_text),
    ("Biomedical", is_biomed),
    ("Sciences", is_science),
])
|
|
|
def extract_org_from_id(model_id):
    """Return the organization part of a Hub model id.

    ``"org/name"`` -> ``"org"`` (only the first segment is used for deeper
    paths). Ids without a ``/``, and non-string ids such as None/NaN from
    missing data, are reported as ``"unaffiliated"`` rather than raising.
    """
    if not isinstance(model_id, str):
        return "unaffiliated"
    org, separator, _ = model_id.partition("/")
    return org if separator else "unaffiliated"
|
|
|
# Tag-filter label -> boolean column pre-computed by load_models_data.
_TAG_FILTER_COLUMNS = {
    "Audio & Speech": "is_audio_speech",
    "Time series": "has_series",
    "Robotics": "has_robot",
    "Music": "has_music",
    "Video": "has_video",
    "Images": "has_image",
    "Text": "has_text",
    "Biomedical": "is_biomed",
    "Sciences": "has_science",
}


def _elapsed_seconds(start_time):
    """Seconds elapsed since *start_time* (a ``pd.Timestamp``)."""
    return (pd.Timestamp.now() - start_time).total_seconds()


def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
    """Filter the models table and shape it for ``create_treemap``.

    Filters are applied in order: tag category (via the cached boolean
    columns), exact ``pipeline_tag``, ``size_category``, then excluded
    organizations. Only models belonging to the *top_k* organizations, ranked
    by the summed *count_by* metric, are kept.

    Args:
        df: models DataFrame produced by load_models_data. It is not modified.
        count_by: metric column used to rank organizations and size the tiles.
        top_k: number of organizations to keep. Cast to int because the UI
            slider may deliver a float.
        tag_filter: a tag-filter label (e.g. "Music"), or None.
        pipeline_filter: exact ``pipeline_tag`` value, or None.
        size_filter: a MODEL_SIZE_RANGES label, or None.
        skip_orgs: organization names to exclude, or None.

    Returns:
        DataFrame with ``id``, ``organization``, *count_by* (numeric, NaN -> 0)
        and ``root`` columns, or an empty DataFrame if no rows survive.
    """
    filter_stats = {"initial": len(df)}
    start_time = pd.Timestamp.now()

    # No up-front df.copy(): each step below builds a new frame, so the
    # caller's DataFrame is never mutated and the large copy is avoided.
    filtered_df = df

    tag_column = _TAG_FILTER_COLUMNS.get(tag_filter) if tag_filter else None
    if tag_column is not None:
        print(f"Applying tag filter: {tag_filter}")
        filtered_df = filtered_df[filtered_df[tag_column]]
        filter_stats["after_tag_filter"] = len(filtered_df)
    print(f"Tag filter applied in {_elapsed_seconds(start_time):.3f} seconds")
    start_time = pd.Timestamp.now()

    if pipeline_filter:
        print(f"Applying pipeline filter: {pipeline_filter}")
        filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
        filter_stats["after_pipeline_filter"] = len(filtered_df)
    print(f"Pipeline filter applied in {_elapsed_seconds(start_time):.3f} seconds")
    start_time = pd.Timestamp.now()

    if size_filter and size_filter in MODEL_SIZE_RANGES:
        print(f"Applying size filter: {size_filter}")
        filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
        print(f"Size filter '{size_filter}' applied.")
        print(f"Models after size filter: {len(filtered_df)}")
        filter_stats["after_size_filter"] = len(filtered_df)
    print(f"Size filter applied in {_elapsed_seconds(start_time):.3f} seconds")
    start_time = pd.Timestamp.now()

    # Keep only the columns needed from here on; assign() returns a new frame,
    # which also avoids setting a column on a boolean-indexed slice.
    filtered_df = filtered_df[["id", count_by]].assign(
        organization=filtered_df["id"].apply(extract_org_from_id)
    )

    if skip_orgs:
        filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
        filter_stats["after_skip_orgs"] = len(filtered_df)

    print("Filter statistics:")
    for stage, count in filter_stats.items():
        print(f" {stage}: {count} models")

    if filtered_df.empty:
        print("Warning: No data left after applying filters!")
        return pd.DataFrame()

    # Coerce the metric before ranking so non-numeric values cannot distort
    # the organization totals.
    metric = pd.to_numeric(filtered_df[count_by], errors="coerce").fillna(0)
    org_totals = metric.groupby(filtered_df["organization"]).sum().sort_values(ascending=False)
    n_top = None if top_k is None else int(top_k)
    top_orgs = org_totals.head(n_top).index

    in_top = filtered_df["organization"].isin(top_orgs)
    treemap_data = filtered_df.loc[in_top, ["id", "organization"]].copy()
    treemap_data[count_by] = metric[in_top]
    # Single root node so plotly draws one tree: models -> organization -> id.
    treemap_data["root"] = "models"

    print(f"Treemap data prepared in {_elapsed_seconds(start_time):.3f} seconds")
    return treemap_data
|
|
|
def create_treemap(treemap_data, count_by, title=None):
    """Render prepared treemap rows as a Plotly treemap figure.

    Args:
        treemap_data: frame with ``root``, ``organization``, ``id`` and
            *count_by* columns (as built by make_treemap_data); may be empty.
        count_by: name of the numeric column that sizes each tile.
        title: optional figure title. When falsy, a default is built from
            *count_by*.

    Returns:
        A plotly Figure. Empty input yields a single placeholder tile whose
        label and title explain that nothing matched.
    """
    margins = dict(t=50, l=25, r=25, b=25)

    if not treemap_data.empty:
        fig = px.treemap(
            treemap_data,
            path=["root", "organization", "id"],
            values=count_by,
            title=title if title else f"HuggingFace Models - {count_by.capitalize()} by Organization",
            color_discrete_sequence=px.colors.qualitative.Plotly,
        )
        fig.update_layout(margin=margins)
        hover_template = (
            "<b>%{label}</b><br>%{value:,} " + count_by
            + "<br>%{percentRoot:.2%} of total<extra></extra>"
        )
        fig.update_traces(textinfo="label+value+percent root", hovertemplate=hover_template)
        return fig

    placeholder = "No data matches the selected filters"
    fig = px.treemap(names=[placeholder], values=[1])
    fig.update_layout(title=placeholder, margin=margins)
    return fig
|
|
|
_HUB_STATS_MODELS_URL = "https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet"

# Columns the app reads from the hub-stats parquet file, in SELECT order.
_NEEDED_MODEL_COLUMNS = ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'safetensors']

# (exclusive upper bound in GB, label) per size bucket, smallest first.
_SIZE_BUCKETS = (
    (1, "Small (<1GB)"),
    (5, "Medium (1-5GB)"),
    (20, "Large (5-20GB)"),
    (50, "X-Large (20-50GB)"),
    (float('inf'), "XX-Large (>50GB)"),
)

# (cached column, predicate on one lower-cased tag), in the original column order.
_TAG_FLAG_RULES = (
    ('has_audio', lambda tag: 'audio' in tag),
    ('has_speech', lambda tag: 'speech' in tag),
    ('has_music', lambda tag: 'music' in tag),
    ('has_robot', lambda tag: 'robot' in tag),
    ('has_bio', lambda tag: 'bio' in tag),
    ('has_med', lambda tag: 'medic' in tag),
    ('has_series', lambda tag: 'series' in tag),
    # "bigscience" is an organization name, not a science domain tag.
    ('has_science', lambda tag: 'science' in tag and 'bigscience' not in tag),
    ('has_video', lambda tag: 'video' in tag),
    ('has_image', lambda tag: 'image' in tag),
    ('has_text', lambda tag: 'text' in tag),
)


def _placeholder_column(col, n_rows):
    """Return the fill value used when *col* is absent from the source data."""
    if col in ('downloads', 'downloadsAllTime', 'likes', 'params'):
        return 0
    if col == 'pipeline_tag':
        return ''
    if col == 'tags':
        return [[] for _ in range(n_rows)]
    if col == 'id':
        return [f"model_{i}" for i in range(n_rows)]
    return None  # safetensors


def _fetch_model_columns(parquet_url):
    """Read the needed columns from *parquet_url* with DuckDB.

    If the column-specific query fails (presumably because a column is missing
    from the file), falls back to ``SELECT *`` and fills absent columns with
    placeholders.
    """
    from_clause = f"FROM read_parquet('{parquet_url}')"
    try:
        return duckdb.sql(f"SELECT {', '.join(_NEEDED_MODEL_COLUMNS)} {from_clause}").df()
    except Exception as sql_error:
        print(f"Error with specific column selection: {sql_error}")
        print("Falling back to select * query...")

    raw_df = duckdb.sql(f"SELECT * {from_clause}").df()
    available_columns = set(raw_df.columns)
    df = pd.DataFrame()
    for col in _NEEDED_MODEL_COLUMNS:
        df[col] = raw_df[col] if col in available_columns else _placeholder_column(col, len(raw_df))
    return df


def _size_category(size_gb):
    """Map a size in GB to its bucket label; negative or NaN sizes map to None."""
    if not size_gb >= 0:  # also True for NaN
        return None
    for upper_bound, label in _SIZE_BUCKETS:
        if size_gb < upper_bound:
            return label
    return _SIZE_BUCKETS[-1][1]  # +inf


def _add_size_columns(df):
    """Add ``params`` and ``size_category`` from ``safetensors``, then drop it."""
    if 'safetensors' not in df.columns:
        df['params'] = 0
        df['size_category'] = None
        return df

    df['params'] = df['safetensors'].apply(extract_model_size)
    df['size_category'] = df['params'].apply(_size_category)

    # value_counts replaces a row-by-row iterrows() loop over every model.
    counts = df['size_category'].value_counts()
    print("Model size distribution:")
    for _, label in _SIZE_BUCKETS:
        print(f" {label}: {int(counts.get(label, 0))} models")

    return df.drop(columns=['safetensors'])


def _normalize_tags(tags_value):
    """Normalise one raw ``tags`` cell into a list of strings (best effort).

    Containers are checked BEFORE missing values: ``pd.isna`` on a list/array
    returns an array whose truth value is ambiguous, which previously turned
    every multi-tag row into ``[]``.
    """
    try:
        if hasattr(tags_value, 'dtype') and hasattr(tags_value, 'tolist'):
            tags_value = tags_value.tolist()
        if isinstance(tags_value, (list, tuple)):
            return [str(tag) for tag in tags_value]
        if tags_value is None or (pd.api.types.is_scalar(tags_value) and pd.isna(tags_value)):
            return []
        if isinstance(tags_value, str):
            try:
                parsed = json.loads(tags_value)
            except ValueError:
                # Not JSON: treat as a comma-separated list.
                return [tag.strip() for tag in tags_value.split(',') if tag.strip()]
            if isinstance(parsed, list):
                return [str(tag) for tag in parsed]
        return [str(tags_value)]
    except Exception as e:
        print(f"Error processing tags: {e}")
        return []


def _add_tag_columns(df):
    """Normalise ``tags`` and add the cached boolean tag-category columns."""
    flag_columns = [column for column, _ in _TAG_FLAG_RULES]
    if 'tags' not in df.columns:
        df['tags'] = [[] for _ in range(len(df))]
        for col in flag_columns + ['is_audio_speech', 'is_biomed']:
            df[col] = False
        return df

    df['tags'] = df['tags'].apply(_normalize_tags)

    print("Pre-calculating cached tag categories...")
    print("Creating cached tag columns...")
    # Lower-case every tag once, instead of once per flag.
    lowered = df['tags'].apply(
        lambda tags: [str(tag).lower() for tag in tags] if isinstance(tags, list) else []
    )
    for column, matches in _TAG_FLAG_RULES:
        df[column] = lowered.apply(lambda tags, matches=matches: any(matches(tag) for tag in tags))

    df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
                             df['pipeline_tag'].str.contains('audio', case=False, na=False) |
                             df['pipeline_tag'].str.contains('speech', case=False, na=False))
    df['is_biomed'] = df['has_bio'] | df['has_med']
    print("Cached tag columns created successfully!")
    return df


def load_models_data():
    """Load model metadata from the Hugging Face hub-stats dataset via DuckDB.

    Besides the raw metrics, pre-computes (caches) derived columns so filtering
    later is a cheap column lookup: ``params``/``size_category`` from
    ``safetensors`` and the boolean ``has_*``/``is_*`` tag-category flags.

    Returns:
        ``(df, True)`` on success; ``(empty DataFrame, False)`` if anything
        fails (the error is printed).
    """
    try:
        print("Fetching data from Hugging Face models.parquet...")
        df = _fetch_model_columns(_HUB_STATS_MODELS_URL)
        print(f"Data fetched successfully. Shape: {df.shape}")

        df = _add_size_columns(df)
        df = _add_tag_columns(df)

        df.fillna({'downloads': 0, 'downloadsAllTime': 0, 'likes': 0, 'params': 0}, inplace=True)
        if 'pipeline_tag' in df.columns:
            df['pipeline_tag'] = df['pipeline_tag'].fillna('')
        else:
            df['pipeline_tag'] = ''

        # Defensive: guarantee every column the UI reads exists.
        for col in ['id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params']:
            if col not in df.columns:
                df[col] = _placeholder_column(col, len(df))

        print(f"Successfully processed {len(df)} models with cached tag and size information")
        return df, True

    except Exception as e:
        print(f"Error loading data: {e}")
        return pd.DataFrame(), False
|
|
|
|
|
# Human-readable label per metric column, in Metric-dropdown order.
_METRIC_LABELS = {
    "downloads": "Downloads (last 30 days)",
    "downloadsAllTime": "Downloads (All Time)",
    "likes": "Likes",
}


def _format_stats_markdown(treemap_data, count_by, metric_label, skip_orgs):
    """Build the Markdown statistics panel for a non-empty treemap frame.

    Lines are joined explicitly (no indented triple-quoted strings), so the
    headings and tables render the same regardless of source indentation.
    """
    total_models = len(treemap_data)
    total_value = treemap_data[count_by].sum()

    def share(value):
        # Guard an all-zero metric (e.g. filtered models with no likes).
        return (value / total_value) * 100 if total_value else 0.0

    top_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
    top_models = treemap_data[["id", count_by]].sort_values(by=count_by, ascending=False).head(5)

    lines = [
        "## Statistics",
        f"- **Total models shown**: {total_models:,}",
        f"- **Total {metric_label}**: {int(total_value):,}",
        "",
        f"## Top Organizations by {metric_label}",
        "",
        f"| Organization | {metric_label} | % of Total |",
        "|--------------|-------------:|----------:|",
    ]
    lines += [f"| {org} | {int(value):,} | {share(value):.2f}% |" for org, value in top_orgs.items()]
    lines += [
        "",
        f"## Top Models by {metric_label}",
        "",
        f"| Model | {metric_label} | % of Total |",
        "|-------|-------------:|----------:|",
    ]
    lines += [
        f"| {model_id} | {int(value):,} | {share(value):.2f}% |"
        for model_id, value in zip(top_models["id"], top_models[count_by])
    ]
    if skip_orgs:
        lines += ["", f"*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"]
    return "\n".join(lines)


with gr.Blocks() as demo:
    # Per-session state: the loaded models DataFrame and a "data ready" flag
    # that gates the Generate button.
    models_data = gr.State()
    loading_complete = gr.State(False)

    with gr.Row():
        gr.Markdown("\n".join([
            "# HuggingFace Models TreeMap Visualization",
            "",
            "This app shows how different organizations contribute to the HuggingFace ecosystem with their models.",
            "Use the filters to explore models by different metrics, tags, pipelines, and model sizes.",
            "",
            "The treemap visualizes models grouped by organization, with the size of each box representing the selected metric.",
        ]))

    with gr.Row():
        with gr.Column(scale=1):
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=[(label, column) for column, label in _METRIC_LABELS.items()],
                value="downloads",
                info="Select the metric to determine box sizes",
            )

            filter_choice_radio = gr.Radio(
                label="Filter Type",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None",
                info="Choose how to filter the models",
            )

            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False,
                info="Filter models by domain/category",
            )

            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False,
                info="Filter models by specific pipeline",
            )

            size_filter_dropdown = gr.Dropdown(
                label="Model Size Filter",
                choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
                value="None",
                info="Filter models by their size (using params column)",
            )

            top_k_slider = gr.Slider(
                label="Number of Top Organizations",
                minimum=5,
                maximum=50,
                value=25,
                step=5,
                info="Number of top organizations to include",
            )

            skip_orgs_textbox = gr.Textbox(
                label="Organizations to Skip (comma-separated)",
                placeholder="e.g., OpenAI, Google",
                value="TheBloke, MaziyarPanahi, unsloth, modularai, Gensyn, bartowski",
            )

            generate_plot_button = gr.Button("Generate Plot", variant="primary", interactive=False)
            refresh_data_button = gr.Button("Refresh Data from Hugging Face", variant="secondary")

        with gr.Column(scale=3):
            plot_output = gr.Plot()
            stats_output = gr.Markdown("*Loading data from Hugging Face...*")
            data_info = gr.Markdown("")

    def enable_plot_button(loaded):
        """Enable the Generate button once data has loaded."""
        return gr.update(interactive=loaded)

    loading_complete.change(
        fn=enable_plot_button,
        inputs=[loading_complete],
        outputs=[generate_plot_button],
    )

    def update_filter_visibility(filter_choice):
        """Show only the dropdown matching the selected filter type."""
        show_tag = filter_choice == "Tag Filter"
        show_pipeline = filter_choice == "Pipeline Filter"
        return gr.update(visible=show_tag), gr.update(visible=show_pipeline)

    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown],
    )

    def load_and_provide_info():
        """Load the models table; return (df, loaded flag, info Markdown, status Markdown)."""
        df, success = load_models_data()
        if not success:
            return pd.DataFrame(), False, "*Error loading data from Hugging Face.*", "*Failed to load data. Please try again.*"

        info_text = "\n".join([
            "### Data Information",
            f"- **Total models loaded**: {len(df):,}",
            f"- **Last update**: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}",
            "- **Data source**: [Hugging Face Hub Stats](https://huggingface.co/datasets/cfahlgren1/hub-stats) (models.parquet)",
        ])
        return df, True, info_text, "*Data loaded successfully. Use the controls to generate a plot.*"

    def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
        """Build the treemap figure and statistics Markdown for the current controls."""
        if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty:
            return None, "Error: Data is still loading. Please wait a moment and try again."

        # Only the filter matching the selected filter type is applied.
        selected_tag_filter = tag_filter if filter_choice == "Tag Filter" else None
        selected_pipeline_filter = pipeline_filter if filter_choice == "Pipeline Filter" else None
        selected_size_filter = size_filter if size_filter != "None" else None
        skip_orgs = [org.strip() for org in (skip_orgs_text or "").split(',') if org.strip()]

        treemap_data = make_treemap_data(
            df=data_df,
            count_by=count_by,
            top_k=int(top_k),  # the slider may deliver a float
            tag_filter=selected_tag_filter,
            pipeline_filter=selected_pipeline_filter,
            size_filter=selected_size_filter,
            skip_orgs=skip_orgs,
        )

        metric_label = _METRIC_LABELS.get(count_by, count_by)
        fig = create_treemap(
            treemap_data=treemap_data,
            count_by=count_by,
            title=f"HuggingFace Models - {metric_label} by Organization",
        )

        if treemap_data.empty:
            return fig, "No data matches the selected filters."
        return fig, _format_stats_markdown(treemap_data, count_by, metric_label, skip_orgs)

    # NOTE(review): every page load runs load_models_data, i.e. re-reads the
    # remote parquet file per session; consider a shared cache if load is heavy.
    demo.load(
        fn=load_and_provide_info,
        inputs=[],
        outputs=[models_data, loading_complete, data_info, stats_output],
    )

    refresh_data_button.click(
        fn=load_and_provide_info,
        inputs=[],
        outputs=[models_data, loading_complete, data_info, stats_output],
    )

    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            size_filter_dropdown,
            top_k_slider,
            skip_orgs_textbox,
            models_data,
        ],
        outputs=[plot_output, stats_output],
    )

if __name__ == "__main__":
    demo.launch()