File size: 9,737 Bytes
bbf45d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import json
import gradio as gr
import pandas as pd
import plotly.express as px

# Pipeline tag values offered in the "Select Pipeline Tag" dropdown; the
# selected value is compared verbatim against each repo's "pipeline_tag" entry.
PIPELINE_TAGS = [
 'text-generation',
 'text-to-image',
 'text-classification',
 'text2text-generation',
 'audio-to-audio',
 'feature-extraction',
 'image-classification',
 'translation',
 'reinforcement-learning',
 'fill-mask',
 'text-to-speech',
 'automatic-speech-recognition',
 'image-text-to-text',
 'token-classification',
 'sentence-similarity',
 'question-answering',
 'image-feature-extraction',
 'summarization',
 'zero-shot-image-classification',
 'object-detection',
 'image-segmentation',
 'image-to-image',
 'image-to-text',
 'audio-classification',
 'visual-question-answering',
 'text-to-video',
 'zero-shot-classification',
 'depth-estimation',
 'text-ranking',
 'image-to-video',
 'multiple-choice',
 'unconditional-image-generation',
 'video-classification',
 'text-to-audio',
 'time-series-forecasting',
 'any-to-any',
 'video-text-to-text',
 'table-question-answering',
]

def is_audio_speech(repo_dct):
    """Return True if the repo's pipeline tag or any tag mentions audio or speech.

    Args:
        repo_dct: Repository metadata mapping. The optional "pipeline_tag"
            entry and the optional "tags" entry (an iterable of strings,
            since ``.lower()`` is called on each item) are inspected.

    Returns:
        bool: True when "audio" or "speech" occurs, case-insensitively, in
        the pipeline tag or in at least one tag. False otherwise, including
        when both entries are missing, None, or empty.
    """
    keywords = ("audio", "speech")
    # `or ""` / `or []` treat a missing key and an explicit None the same way,
    # so the function always yields a real bool instead of leaking None / [].
    pipeline_tag = (repo_dct.get("pipeline_tag") or "").lower()
    if any(word in pipeline_tag for word in keywords):
        return True
    tags = repo_dct.get("tags") or []
    return any(word in tag.lower() for tag in tags for word in keywords)

def is_music(repo_dct):
    """Tell whether any of the repo's tags contains "music" (case-insensitive).

    A missing or empty "tags" entry is handed back unchanged (a falsy value);
    otherwise the result is True or False.
    """
    tags = repo_dct.get("tags", None)
    if not tags:
        return tags
    for tag in tags:
        if "music" in tag.lower():
            return True
    return False

def is_robotics(repo_dct):
    """Tell whether any of the repo's tags contains "robot" (case-insensitive).

    When the "tags" entry is missing or empty, that falsy value itself is
    returned instead of a bool.
    """
    tag_list = repo_dct.get("tags", None)
    return any("robot" in name.lower() for name in tag_list) if tag_list else tag_list

def is_biomed(repo_dct):
    """Tell whether any tag contains "bio" or "medic" (case-insensitive).

    All tags are checked for "bio" first, and only if none matches are they
    checked for "medic". A missing or empty "tags" entry is returned as-is.
    """
    tags = repo_dct.get("tags", None)
    if not tags:
        return tags

    def mentions(fragment):
        # True when at least one lowercased tag contains `fragment`.
        return any(fragment in tag.lower() for tag in tags)

    return mentions("bio") or mentions("medic")

def is_timeseries(repo_dct):
    """Tell whether any of the repo's tags contains "series" (case-insensitive).

    A missing or empty "tags" entry is returned unchanged (falsy).
    """
    tags = repo_dct.get("tags", None)
    if tags:
        # Stops at the first matching tag; False when no tag matches.
        return next((True for tag in tags if "series" in tag.lower()), False)
    return tags

def is_science(repo_dct):
    """Return True if any tag mentions "science", ignoring "bigscience" tags.

    Args:
        repo_dct: Repository metadata mapping; the optional "tags" entry (an
            iterable of strings) is inspected.

    Returns:
        bool: True when at least one tag contains "science" but not
        "bigscience", both compared case-insensitively. False otherwise,
        including when "tags" is missing, None, or empty.
    """
    for tag in repo_dct.get("tags") or []:
        lowered = tag.lower()
        # Both tests use the lowercased tag, so "BigScience" is excluded the
        # same way "bigscience" is.
        if "science" in lowered and "bigscience" not in lowered:
            return True
    return False

def is_video(repo_dct):
    """Tell whether any of the repo's tags contains "video" (case-insensitive).

    A missing or empty "tags" entry is returned unchanged (falsy).
    """
    tags = repo_dct.get("tags", None)
    if tags:
        # Lazy generator: tags are lowercased one at a time, stopping at a hit.
        lowered = (tag.lower() for tag in tags)
        return any("video" in name for name in lowered)
    return tags

def is_image(repo_dct):
    """Tell whether any of the repo's tags contains "image" (case-insensitive).

    A missing or empty "tags" entry is returned unchanged (falsy).
    """
    tags = repo_dct.get("tags", None)
    if not tags:
        return tags
    found = False
    for tag in tags:
        if "image" in tag.lower():
            found = True
            break
    return found

def is_text(repo_dct):
    """Tell whether any of the repo's tags contains "text" (case-insensitive).

    A missing or empty "tags" entry is returned unchanged (falsy).
    """
    tags = repo_dct.get("tags", None)
    if tags:
        return any("text" in label.lower() for label in tags)
    return tags

# Maps each label offered in the "Select Tag" dropdown to the predicate that
# decides whether a repo dict belongs to that topic/modality.
TAG_FILTER_FUNCS = {
    "Audio & Speech": is_audio_speech,
    "Time series": is_timeseries,
    "Robotics": is_robotics,
    "Music": is_music,
    "Video": is_video,
    "Images": is_image,
    "Text": is_text,
    "Biomedical": is_biomed,
    "Sciences": is_science,
}

def make_org_stats(repo_type, count_by, org_stats, top_k=20, filter_func=None):
    """Rank organizations by a metric and express the result as percentages.

    Args:
        repo_type: "all", "datasets" or "models"; "all" combines datasets and
            models into one total per organization.
        count_by: Key summed over repos: "likes", "downloads" or
            "downloads_all".
        org_stats: Mapping of organization name to a mapping of repo type to
            a list of repo dicts. Each counted repo dict must provide an "id"
            entry and the ``count_by`` entry.
        top_k: Number of organizations reported individually; all remaining
            ones are merged into a single "Others..." entry.
        filter_func: Optional predicate taking a repo dict; only repos for
            which it is truthy are counted. ``None`` keeps every repo.

    Returns:
        A ``(org_shares, plot_rows)`` tuple:
        * ``org_shares``: ``(organization, percent)`` pairs for the top-k
          organizations followed by "Others...", sorted by descending share;
          entries whose total is not positive are left out.
        * ``plot_rows``: ``(organization, repo id, percent)`` triples for
          every counted repo of the top-k organizations, followed by one
          ``("Others...", "other", percent)`` row.
        Both lists are empty when the overall total is zero (for example
        when no repo passes the filter).

    Raises:
        AssertionError: if ``count_by`` or ``repo_type`` is not one of the
            accepted values.
    """
    assert count_by in ["likes", "downloads", "downloads_all"]
    assert repo_type in ["all", "datasets", "models"]
    repos = ["datasets", "models"] if repo_type == "all" else [repo_type]
    keep = filter_func if filter_func is not None else (lambda _dct: True)

    # Collect each organization's matching repos exactly once, across all
    # requested repo types. Ranking per (repo type, organization) pair would
    # list an organization twice for repo_type="all" and count its repos
    # twice in the plot rows.
    kept_repos = {
        author: [
            dct
            for repo in repos
            for dct in author_dct.get(repo, [])
            if keep(dct)
        ]
        for author, author_dct in org_stats.items()
    }
    sorted_stats = sorted(
        (
            (author, sum(dct[count_by] for dct in dcts))
            for author, dcts in kept_repos.items()
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )
    top = sorted_stats[:top_k]
    others_total = sum(st for _, st in sorted_stats[top_k:])
    total_st = sum(st for _, st in top) + others_total
    if not total_st:
        # Nothing to report; also avoids dividing by zero below.
        return [], []

    res_plot_df = [
        (org, dct["id"], dct[count_by] * 100 / total_st)
        for org, _ in top
        for dct in kept_repos[org]
    ]
    res_plot_df.append(("Others...", "other", others_total * 100 / total_st))

    res = top + [("Others...", others_total)]
    return ([(o, 100 * st / total_st) for o, st in res if st > 0], res_plot_df)

def make_figure(count_by, repo_type, org_stats, tag_filter=None, pipeline_filter=None):
    """Build a treemap of each organization's repos, sized by their share.

    Args:
        count_by: "downloads", "likes" or "downloads_all".
        repo_type: "all", "models" or "datasets".
        org_stats: Organization statistics, passed through to make_org_stats.
        tag_filter: Optional key of TAG_FILTER_FUNCS selecting a tag predicate.
        pipeline_filter: Optional pipeline tag; only repos whose
            "pipeline_tag" equals it are kept.

    Returns:
        The figure produced by ``px.treemap`` with the hierarchy
        repo_type -> organization -> repo.

    At most one of ``tag_filter`` and ``pipeline_filter`` may be given.
    """
    assert count_by in ["downloads", "likes", "downloads_all"]
    assert repo_type in ["all", "models", "datasets"]
    assert tag_filter is None or pipeline_filter is None

    predicate = TAG_FILTER_FUNCS[tag_filter] if tag_filter else None
    if pipeline_filter:
        def predicate(dct):
            # Falsy when the repo has no pipeline tag, else an equality test.
            tag = dct.get("pipeline_tag", None)
            return tag and tag == pipeline_filter

    plot_rows = make_org_stats(
        repo_type, count_by, org_stats, top_k=25, filter_func=predicate
    )[1]
    frame = pd.DataFrame(
        {
            "organizations": [row[0] for row in plot_rows],
            "repo": [row[1] for row in plot_rows],
            "stats": [row[2] for row in plot_rows],
        }
    )
    # A constant column named after the repo type gives the treemap one root.
    frame[repo_type] = repo_type
    treemap = px.treemap(
        frame, path=[repo_type, "organizations", "repo"], values="stats"
    )
    treemap.update_layout(
        treemapcolorway=["pink"] * len(plot_rows),
        margin={"t": 50, "l": 25, "r": 25, "b": 25},
    )
    return treemap


# UI definition: a header row, then a row with the controls (left column)
# and the plot (right column).
with gr.Blocks() as demo:
    org_stats_data = gr.State(value=None)  # To store loaded data

    # Header text.
    with gr.Row():
        gr.Markdown("""
            ## Hugging Face Organization Stats
            
            This app shows how different organizations are contributing to different aspects of the open AI ecosystem.
            Use the dropdowns on the left to select repository types, metrics, and optionally tags representing topics or modalities of interest.
        """)
    with gr.Row():
        # Left column: selection controls.
        with gr.Column(scale=1):
            repo_type_dropdown = gr.Dropdown(
                label="Repository Type",
                choices=["all", "models", "datasets"],
                value="all"
            )
            count_by_dropdown = gr.Dropdown(
                label="Metric",
                choices=["downloads", "likes", "downloads_all"],
                value="downloads"
            )
            
            # Chooses which of the two filter dropdowns below (if any) applies.
            filter_choice_radio = gr.Radio(
                label="Filter by",
                choices=["None", "Tag Filter", "Pipeline Filter"],
                value="None"
            )
            
            # Both filter dropdowns start hidden; update_filter_visibility
            # toggles them when the radio selection changes.
            tag_filter_dropdown = gr.Dropdown(
                label="Select Tag",
                choices=list(TAG_FILTER_FUNCS.keys()),
                value=None,
                visible=False
            )
            pipeline_filter_dropdown = gr.Dropdown(
                label="Select Pipeline Tag",
                choices=PIPELINE_TAGS,
                value=None,
                visible=False
            )

            generate_plot_button = gr.Button("Generate Plot")

        # Right column: the generated treemap.
        with gr.Column(scale=3):
            plot_output = gr.Plot()

    def generate_plot_on_click(repo_type, count_by, filter_choice, tag_filter, pipeline_filter, data):
        """Button callback: log the current selections and build the treemap.

        Only the filter matching ``filter_choice`` is forwarded to
        make_figure; the other one is passed as None. Returns None (after
        logging an error) when ``data`` is None, otherwise the figure.
        """
        use_tag = filter_choice == "Tag Filter"
        use_pipeline = (not use_tag) and filter_choice == "Pipeline Filter"

        # Echo the inputs so each request can be traced in the server log.
        print("Generating plot with the following inputs:")
        print(f"  Repository Type: {repo_type}")
        print(f"  Metric (Count By): {count_by}")
        print(f"  Filter Choice: {filter_choice}")
        if use_tag:
            print(f"  Tag Filter: {tag_filter}")
        elif use_pipeline:
            print(f"  Pipeline Filter: {pipeline_filter}")

        if data is None:
            print("Error: Data not loaded yet.")
            return None

        return make_figure(
            count_by=count_by,
            repo_type=repo_type,
            org_stats=data,
            tag_filter=tag_filter if use_tag else None,
            pipeline_filter=pipeline_filter if use_pipeline else None,
        )

    def update_filter_visibility(filter_choice):
        """Return visibility updates for the (tag, pipeline) filter dropdowns.

        The tag dropdown is shown only for "Tag Filter", the pipeline dropdown
        only for "Pipeline Filter"; any other choice hides both.
        """
        show_tag = filter_choice == "Tag Filter"
        show_pipeline = (not show_tag) and filter_choice == "Pipeline Filter"
        return gr.update(visible=show_tag), gr.update(visible=show_pipeline)

    # Whenever the radio selection changes, recompute which of the two filter
    # dropdowns is visible.
    filter_choice_radio.change(
        fn=update_filter_visibility,
        inputs=[filter_choice_radio],
        outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
    )
    
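    # Load the stats file into the session state when the page loads
    # (demo.load fires on page load in the browser, not once at server startup)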
    def load_org_data():
        """Read the organization statistics JSON file and return its content.

        Returns:
            The object decoded from "org_to_artifacts_2l_stats.json", a path
            resolved relative to the current working directory.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        print("Loading organization statistics data...")
        # The context manager closes the handle as soon as parsing is done;
        # the explicit encoding avoids depending on the platform default.
        with open("org_to_artifacts_2l_stats.json", encoding="utf-8") as stats_file:
            loaded_org_stats = json.load(stats_file)
        print("Data loaded successfully.")
        return loaded_org_stats

    # Register load_org_data on the Blocks load event; its return value is
    # written into the org_stats_data state.
    demo.load(
        fn=load_org_data,
        inputs=[], # No inputs needed to just load data
        outputs=[org_stats_data] # Only output to the state
    )

    # Button click event to generate plot: the current control values and the
    # loaded stats are passed, in this order, to generate_plot_on_click, and
    # the returned figure is shown in plot_output.
    generate_plot_button.click(
        fn=generate_plot_on_click,
        inputs=[
            repo_type_dropdown,
            count_by_dropdown,
            filter_choice_radio,
            tag_filter_dropdown,
            pipeline_filter_dropdown,
            org_stats_data
        ],
        outputs=[plot_output]
    )


if __name__ == "__main__":
    # The stats file is read by the demo.load handler, so launching the app is
    # all that is left to do when this file is run as a script.
    demo.launch()