|
from datasets import load_dataset |
|
import pandas as pd |
|
import duckdb |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
import gradio as gr |
|
import os |
|
from huggingface_hub import login |
|
from datetime import datetime, timedelta |
|
import sys |
|
|
|
|
|
# --- Hugging Face authentication -------------------------------------------
# The access token comes from the environment; start-up aborts when it is
# missing or empty.
HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    raise ValueError("Please set the HF_TOKEN environment variable")

login(token=HF_TOKEN)

# --- Seaborn defaults --------------------------------------------------------
sns.set_theme(style="whitegrid")
sns.set_context("notebook")

# --- Data loading ------------------------------------------------------------
# Fetch the "models" split, convert it to pandas and register the frame with
# DuckDB under the name `models` so SQL queries can read it.
try:
    dataset = load_dataset("reach-vb/trending-repos", split="models")
    df = dataset.to_pandas()
    duckdb.register('models', df)
except Exception as load_error:
    # Report the failure, then re-raise so start-up still stops.
    print(f"Error loading dataset: {load_error}")
    raise
|
|
|
def get_retention_data(start_date: str, end_date: str) -> pd.DataFrame:
    """Compute day-over-day retention of models between two dates.

    For every collection day in the inclusive range ``start_date`` ..
    ``end_date`` the query over the ``models`` relation reports how many
    distinct models were listed that day, how many of those were also listed
    on the previous calendar day, and that share as a percentage.

    Args:
        start_date: Inclusive lower bound, bound as the first query parameter.
        end_date: Inclusive upper bound, bound as the second query parameter.

    Returns:
        A DataFrame with the columns ``collection_day``,
        ``total_models_today``, ``carried_over_models`` and
        ``percent_retained``, ordered by day. If anything fails, the error is
        printed to stderr and a one-row DataFrame with a single ``Error``
        column holding the message is returned instead of raising.
    """
    try:
        # `model_presence` uses SELECT DISTINCT so each (model, day) pair is
        # counted once. Without it, several snapshots of the same model on
        # one day would inflate the daily total and multiply through the
        # self-join, allowing retention values above 100%.
        query = """
            WITH model_presence AS (
                SELECT DISTINCT
                    id AS model_id,
                    collected_at::DATE AS collection_day
                FROM models
            ),
            daily_model_counts AS (
                SELECT
                    collection_day,
                    COUNT(*) AS total_models_today
                FROM model_presence
                GROUP BY collection_day
            ),
            retained_models AS (
                SELECT
                    a.collection_day,
                    COUNT(*) AS previously_existed_count
                FROM model_presence a
                JOIN model_presence b
                    ON a.model_id = b.model_id
                    AND a.collection_day = b.collection_day + INTERVAL '1 day'
                GROUP BY a.collection_day
            )
            SELECT
                d.collection_day,
                d.total_models_today,
                COALESCE(r.previously_existed_count, 0) AS carried_over_models,
                CASE
                    WHEN d.total_models_today = 0 THEN NULL
                    ELSE ROUND(COALESCE(r.previously_existed_count, 0) * 100.0 / d.total_models_today, 2)
                END AS percent_retained
            FROM daily_model_counts d
            LEFT JOIN retained_models r ON d.collection_day = r.collection_day
            WHERE d.collection_day BETWEEN ? AND ?
            ORDER BY d.collection_day
        """

        # The date filter is applied last, so the first day inside the range
        # is still compared against the day before it.
        result = duckdb.query(query, params=[start_date, end_date]).to_df()
        print("SQL Query Result:")
        print(result)
        return result
    except Exception as e:
        # UI boundary: hand the failure to the plotting step as data so it
        # can be shown to the user instead of crashing the request.
        print(f"Error in get_retention_data: {e}", file=sys.stderr)
        return pd.DataFrame({"Error": [str(e)]})
|
|
|
def _message_figure(message: str) -> "go.Figure":
    """Return an otherwise empty figure showing *message* at its centre."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper", yref="paper",
        x=0.5, y=0.5, showarrow=False,
        font=dict(size=16),
    )
    return fig


def plot_retention_data(dataframe: pd.DataFrame):
    """Build a bar chart of daily retention percentages.

    Args:
        dataframe: Either a frame with ``collection_day`` and
            ``percent_retained`` columns, or a non-empty frame with an
            ``Error`` column whose first value is an error message.
            The argument is not modified; all clean-up happens on a copy.

    Returns:
        A Plotly figure. Normally a bar chart with one bar per day and a
        dashed red line at the mean retention. If the input carries an error,
        has no plottable rows after clean-up, or plotting itself fails, the
        figure contains only an explanatory text annotation.
    """
    print("DataFrame received by plot_retention_data (first 5 rows):")
    print(dataframe.head())
    print("\nData types in plot_retention_data before any conversion:")
    print(dataframe.dtypes)

    # An "Error" column with at least one row means the input is an error
    # report rather than data; show its first message.
    if "Error" in dataframe.columns and not dataframe.empty:
        error_message = dataframe['Error'].iloc[0]
        print(f"Error DataFrame received: {error_message}", file=sys.stderr)
        return _message_figure(f"Error from data generation: {error_message}")

    try:
        for column in ('percent_retained', 'collection_day'):
            if column not in dataframe.columns:
                raise ValueError(f"'{column}' column is missing from the DataFrame.")

        # Work on a copy so the caller's frame is left untouched.
        frame = dataframe.copy()
        frame['percent_retained'] = pd.to_numeric(frame['percent_retained'], errors='coerce')
        frame['collection_day'] = pd.to_datetime(frame['collection_day'])

        # Rows with a missing percentage or a missing date cannot be drawn.
        frame = frame.dropna(subset=['percent_retained', 'collection_day'])

        print("\n'percent_retained' column after pd.to_numeric (first 5 values):")
        print(frame['percent_retained'].head())
        print("'percent_retained' dtype after pd.to_numeric:", frame['percent_retained'].dtype)
        print("\n'collection_day' column after pd.to_datetime (first 5 values):")
        print(frame['collection_day'].head())
        print("'collection_day' dtype after pd.to_datetime:", frame['collection_day'].dtype)

        if frame.empty:
            return _message_figure("No data available to plot after processing.")

        fig = px.bar(
            frame,
            x='collection_day',
            y='percent_retained',
            title='Previous Day Top 200 Trending Model Retention %',
            labels={'collection_day': 'Date', 'percent_retained': 'Retention Rate (%)'},
            text='percent_retained',
        )

        fig.update_traces(
            texttemplate='%{text:.2f}%',
            textposition='inside',
            insidetextanchor='middle',
            textfont_color='white',
            textfont_size=10,
            hovertemplate='<b>Date</b>: %{x|%Y-%m-%d}<br>' +
                          '<b>Retention</b>: %{y:.2f}%<extra></extra>',
        )

        # The frame is non-empty and free of missing percentages here, so the
        # mean is always defined.
        average_retention = frame['percent_retained'].mean()
        fig.add_hline(
            y=average_retention,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Average: {average_retention:.2f}%",
            annotation_position="bottom right",
        )

        fig.update_xaxes(tickangle=45)
        fig.update_layout(
            title_x=0.5,
            xaxis_title="Date",
            yaxis_title="Retention Rate (%)",
            plot_bgcolor='white',
            bargap=0.2,
        )
        return fig
    except Exception as e:
        # UI boundary: show the failure in the plot area instead of raising.
        print(f"Error during plot_retention_data: {e}", file=sys.stderr)
        return _message_figure(f"Plotting Error: {str(e)}")
|
|
|
def interface_fn(start_date, end_date):
    """Gradio callback: query retention for the date range and plot it.

    Passes both arguments to ``get_retention_data`` and returns whatever
    ``plot_retention_data`` produces for the resulting frame.
    """
    return plot_retention_data(get_retention_data(start_date, end_date))
|
|
|
|
|
# Default date range for the UI: the smallest and largest `collected_at`
# values in the loaded frame.
# NOTE(review): `datetime.fromisoformat` implies the column holds ISO-format
# strings - confirm against the dataset schema.
min_date = datetime.fromisoformat(df['collected_at'].min()).date()
max_date = datetime.fromisoformat(df['collected_at'].max()).date()

# Two free-text date inputs feed `interface_fn`; its figure is shown in a
# single Plotly output.
iface = gr.Interface(
    title="Model Retention Analysis",
    description="Visualize model retention rates over time. Enter dates in YYYY-MM-DD format.",
    fn=interface_fn,
    inputs=[
        gr.Textbox(
            label="Start Date (YYYY-MM-DD)",
            value=min_date.strftime("%Y-%m-%d"),
        ),
        gr.Textbox(
            label="End Date (YYYY-MM-DD)",
            value=max_date.strftime("%Y-%m-%d"),
        ),
    ],
    outputs=gr.Plotly(label="Model Retention Visualization"),
)

iface.launch()