File size: 4,082 Bytes
0da2f8d
d4fbda3
 
0da2f8d
 
d4fbda3
 
 
 
 
 
 
 
 
 
 
0da2f8d
d4fbda3
 
0da2f8d
d4fbda3
 
 
bd3fe94
 
 
 
d4fbda3
 
 
 
0da2f8d
d4fbda3
 
 
 
bd3fe94
 
 
 
 
 
 
 
d4fbda3
 
 
bd3fe94
 
d4fbda3
 
 
 
 
 
bd3fe94
 
 
 
 
 
 
d4fbda3
bd3fe94
d4fbda3
8c6fcfb
 
 
d4fbda3
 
 
bd3fe94
d4fbda3
 
 
 
 
 
 
 
 
 
 
0da2f8d
 
4a901f1
 
d4fbda3
bd3fe94
d4fbda3
 
 
 
 
 
 
 
8c6fcfb
 
 
 
 
d4fbda3
 
 
 
8c6fcfb
 
 
 
d4fbda3
 
 
 
 
 
 
 
 
0da2f8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import csv
import html
import tempfile
from pathlib import Path

import gradio as gr
import pandas as pd
from huggingface_hub import HfApi

def _safe_filename_stem(text):
    """Return *text* with characters that are unsafe in a file name replaced by ``_``.

    Letters, digits, ``-``, ``_`` and ``.`` are kept. Everything else becomes
    ``_``, notably ``/`` as in ``"google/gemma"``, which would otherwise be
    treated as a directory separator.
    """
    return "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in text)


def _download_count(model, attribute):
    """Return the download count stored on *model* under *attribute*, or 0.

    ``getattr(obj, name, 0)`` only covers a missing attribute. An attribute
    that is present but ``None`` must also map to 0 so the running totals
    stay integers.
    """
    value = getattr(model, attribute, None)
    return 0 if value is None else value


def get_model_stats(search_term, *, api=None):
    """Collect download statistics for Hub models matching *search_term*.

    The rows are written to a CSV file in a fresh temporary directory and
    also returned as a DataFrame. The directory is not deleted here; the CSV
    must outlive this call so the UI can offer it for download.

    Args:
        search_term: Free-text query passed to ``HfApi.list_models(search=...)``.
            Leading and trailing whitespace is ignored.
        api: Optional pre-configured ``HfApi`` instance (e.g. one carrying a
            token). When omitted, a default ``HfApi()`` is created.

    Returns:
        A ``(df, status_message, csv_path)`` tuple.
        ``df`` has the columns "Model ID", "Downloads (30 days)" and
        "Downloads (All Time)". ``status_message`` summarizes the count and
        totals. ``csv_path`` is the CSV path as a string, or ``None`` when the
        search term is blank.
    """
    columns = ["Model ID", "Downloads (30 days)", "Downloads (All Time)"]

    search_term = (search_term or "").strip()
    if not search_term:
        # An empty query would enumerate every model on the Hub, so skip it.
        return pd.DataFrame(columns=columns), "Please enter a search term.", None

    if api is None:
        api = HfApi()

    # Sanitize the term before using it in a file name: "google/gemma" would
    # otherwise point into a non-existent subdirectory and fail to open.
    temp_dir = tempfile.mkdtemp()
    output_file = Path(temp_dir) / f"{_safe_filename_stem(search_term)}_models_alltime.csv"

    print(f"Fetching {search_term} models with download statistics...")
    models_generator = api.list_models(
        search=search_term,
        expand=["downloads", "downloadsAllTime"],  # Get both 30-day and all-time downloads
        sort="_id",  # Sort by ID to avoid timeout issues
    )

    rows = []
    total_30day_downloads = 0
    total_alltime_downloads = 0

    with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(columns)
        for model in models_generator:
            downloads_30day = _download_count(model, "downloads")
            downloads_alltime = _download_count(model, "downloads_all_time")
            total_30day_downloads += downloads_30day
            total_alltime_downloads += downloads_alltime

            row = [getattr(model, "id", "Unknown"), downloads_30day, downloads_alltime]
            csv_writer.writerow(row)
            rows.append(row)

    # Build the DataFrame from the rows in memory rather than re-reading the
    # CSV. Re-parsing would turn ids such as "NA" into NaN.
    df = pd.DataFrame(rows, columns=columns)

    status_message = (
        f"Found {len(rows)} models for search term '{search_term}'\n"
        f"Total 30-day downloads: {total_30day_downloads:,}\n"
        f"Total all-time downloads: {total_alltime_downloads:,}"
    )

    return df, status_message, str(output_file)

def format_model_link(model_id):
    """Return an HTML anchor linking *model_id* to its Hugging Face Hub page.

    The id is converted to ``str`` and HTML-escaped before it is placed in
    both the ``href`` attribute and the link text. Characters such as ``<``,
    ``>``, ``&`` or ``"`` therefore cannot break out of the markup.

    Args:
        model_id: Hub repository id, e.g. ``"google/gemma-2b"``.

    Returns:
        An ``<a>`` element with ``target="_blank"`` pointing at
        ``https://huggingface.co/<model_id>``.
    """
    safe_id = html.escape(str(model_id), quote=True)
    return f'<a href="https://huggingface.co/{safe_id}" target="_blank">{safe_id}</a>'

# Create the Gradio interface
with gr.Blocks(title="Hugging Face Model Statistics") as demo:
    gr.Markdown("# Hugging Face Model Statistics")
    gr.Markdown("Enter a search term to find model statistics from Hugging Face Hub")

    with gr.Row():
        search_input = gr.Textbox(
            label="Search Term",
            placeholder="Enter a model name or keyword (e.g., 'gemma', 'llama')",
            value="gemma",
        )
        search_button = gr.Button("Search")

    with gr.Row():
        output_table = gr.Dataframe(
            headers=["Model ID", "Downloads (30 days)", "Downloads (All Time)"],
            # "markdown" lets the Model ID cells render the <a> links built by
            # format_model_link; with "str" the raw HTML is shown as plain text.
            datatype=["markdown", "number", "number"],
            label="Model Statistics",
            wrap=True,
        )
        status_message = gr.Textbox(label="Status", lines=3)  # Three lines: count plus both totals

    with gr.Row():
        download_button = gr.Button("Download CSV")
        # Hidden until the user asks for the CSV (see reveal_csv below).
        csv_file = gr.File(label="CSV File", visible=False)

    # Path of the most recently generated CSV, carried between events.
    csv_path = gr.State()

    def process_results(df, status, csv_path):
        """Turn the Model ID column into clickable links; pass the rest through."""
        df['Model ID'] = df['Model ID'].apply(format_model_link)
        return df, status, csv_path

    def search_and_format(search_term):
        """Fetch statistics for *search_term* and format the links in one step.

        The original code ran the formatter in a separate .then() step that
        re-read the table. A .then() step runs even if the search fails, so it
        could re-format the previous, already-linked table and nest <a> tags.
        A single handler avoids that.
        """
        return process_results(*get_model_stats(search_term))

    def reveal_csv(path):
        """Populate the file component with *path* and show it when a CSV exists."""
        return gr.update(value=path, visible=path is not None)

    search_button.click(
        fn=search_and_format,
        inputs=search_input,
        outputs=[output_table, status_message, csv_path],
    )

    download_button.click(
        fn=reveal_csv,
        inputs=csv_path,
        outputs=csv_file,
    )

if __name__ == "__main__":
    demo.launch()