from datasets import load_dataset
import pandas as pd
import duckdb
import matplotlib.pyplot as plt
import seaborn as sns  # Import Seaborn
import plotly.express as px  # Added for Plotly
import plotly.graph_objects as go  # Added for Plotly error figure
import gradio as gr
import os
from huggingface_hub import login
from datetime import datetime, timedelta
import sys  # Added for error logging

# Get token from environment variable
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    raise ValueError("Please set the HF_TOKEN environment variable")

# Login to Hugging Face
login(token=HF_TOKEN)

# Apply Seaborn theme and context globally.
# NOTE: these settings style matplotlib/seaborn figures; the chart produced
# below is built with Plotly.
sns.set_theme(style="whitegrid")
sns.set_context("notebook")

# Load dataset once at startup
try:
    dataset = load_dataset("reach-vb/trending-repos", split="models")
    df = dataset.to_pandas()
    # Register the pandas DataFrame as a DuckDB table named 'models'
    # so the SQL query below can use 'FROM models'.
    duckdb.register('models', df)
except Exception as e:
    print(f"Error loading dataset: {e}")
    raise


# Day-over-day retention per collection day: how many of the models present on
# a day were also present on the previous calendar day.
_RETENTION_QUERY = """
WITH model_presence AS (
    -- DISTINCT: a model collected several times on the same day must count
    -- once, otherwise the self-join below multiplies rows and the retention
    -- rate can exceed 100%.
    SELECT DISTINCT
        id AS model_id,
        collected_at::DATE AS collection_day
    FROM models
),
daily_model_counts AS (
    SELECT collection_day, COUNT(*) AS total_models_today
    FROM model_presence
    GROUP BY collection_day
),
retained_models AS (
    SELECT a.collection_day, COUNT(*) AS previously_existed_count
    FROM model_presence a
    JOIN model_presence b
      ON a.model_id = b.model_id
     AND a.collection_day = b.collection_day + INTERVAL '1 day'
    GROUP BY a.collection_day
)
SELECT
    d.collection_day,
    d.total_models_today,
    COALESCE(r.previously_existed_count, 0) AS carried_over_models,
    CASE
        WHEN d.total_models_today = 0 THEN NULL
        -- No snapshot exists for the previous day (first collected day or a
        -- gap in collection): retention is unknown, not 0%.
        WHEN p.collection_day IS NULL THEN NULL
        ELSE ROUND(COALESCE(r.previously_existed_count, 0) * 100.0 / d.total_models_today, 2)
    END AS percent_retained
FROM daily_model_counts d
LEFT JOIN retained_models r
       ON d.collection_day = r.collection_day
LEFT JOIN daily_model_counts p
       ON p.collection_day + INTERVAL '1 day' = d.collection_day
WHERE d.collection_day BETWEEN ? AND ?
ORDER BY d.collection_day
"""


def get_retention_data(start_date: str, end_date: str) -> pd.DataFrame:
    """Return day-over-day model retention for collection days in a date range.

    Args:
        start_date: Inclusive lower bound, formatted ``YYYY-MM-DD``.
        end_date: Inclusive upper bound, formatted ``YYYY-MM-DD``.

    Returns:
        A DataFrame with columns ``collection_day``, ``total_models_today``,
        ``carried_over_models`` and ``percent_retained``. ``percent_retained``
        is NULL/NaN when the previous day has no snapshot. On any failure,
        returns a one-row DataFrame with a single ``Error`` column holding
        the message, which ``plot_retention_data`` renders as an error figure.
    """
    try:
        # Validate the user-typed dates first so a typo yields a clear message
        # rather than an opaque DuckDB cast error; also tolerates stray spaces.
        start = datetime.strptime(str(start_date).strip(), "%Y-%m-%d").date()
        end = datetime.strptime(str(end_date).strip(), "%Y-%m-%d").date()

        # Dates are bound as query parameters, never interpolated into the SQL.
        result = duckdb.query(
            _RETENTION_QUERY, params=[start.isoformat(), end.isoformat()]
        ).to_df()
        print("SQL Query Result:")
        print(result)
        return result
    except Exception as e:
        # Log to stderr and hand an error-signal frame to the plotting step.
        print(f"Error in get_retention_data: {e}", file=sys.stderr)
        return pd.DataFrame({"Error": [str(e)]})


def _message_figure(text: str) -> go.Figure:
    """Return an empty Plotly figure that shows *text* centred in the plot area."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=16),
    )
    return fig


def plot_retention_data(dataframe: pd.DataFrame):
    """Build a Plotly bar chart of daily retention with a dashed average line.

    Args:
        dataframe: Output of ``get_retention_data``. This is either retention
            rows or a one-row frame with an ``Error`` column. It is not modified.

    Returns:
        A ``plotly.graph_objects.Figure``. On error or when no plottable rows
        remain, the figure contains only an explanatory text annotation.
    """
    print("DataFrame received by plot_retention_data (first 5 rows):")
    print(dataframe.head())
    print("\nData types in plot_retention_data before any conversion:")
    print(dataframe.dtypes)

    # The previous step signals failure with an 'Error' column.
    if "Error" in dataframe.columns and not dataframe.empty:
        error_message = dataframe['Error'].iloc[0]
        print(f"Error DataFrame received: {error_message}", file=sys.stderr)
        return _message_figure(f"Error from data generation: {error_message}")

    try:
        for column in ('percent_retained', 'collection_day'):
            if column not in dataframe.columns:
                raise ValueError(f"'{column}' column is missing from the DataFrame.")

        # Work on a copy so the caller's DataFrame is left untouched.
        data = dataframe.copy()
        # Ensure 'percent_retained' is numeric and 'collection_day' is datetime for Plotly.
        data['percent_retained'] = pd.to_numeric(data['percent_retained'], errors='coerce')
        data['collection_day'] = pd.to_datetime(data['collection_day'], errors='coerce')
        # Drop rows that failed conversion (NaN/NaT) or whose retention is
        # undefined (no previous-day snapshot).
        data = data.dropna(subset=['percent_retained', 'collection_day'])

        print("\n'percent_retained' column after pd.to_numeric (first 5 values):")
        print(data['percent_retained'].head())
        print("'percent_retained' dtype after pd.to_numeric:", data['percent_retained'].dtype)
        print("\n'collection_day' column after pd.to_datetime (first 5 values):")
        print(data['collection_day'].head())
        print("'collection_day' dtype after pd.to_datetime:", data['collection_day'].dtype)

        if data.empty:
            return _message_figure("No data available to plot after processing.")

        fig = px.bar(
            data,
            x='collection_day',
            y='percent_retained',
            title='Previous Day Top 200 Trending Model Retention %',
            labels={'collection_day': 'Date', 'percent_retained': 'Retention Rate (%)'},
            text='percent_retained',  # Use the column directly for bar labels
        )

        # Format the text on bars and the hover tooltip.
        fig.update_traces(
            texttemplate='%{text:.2f}%',
            textposition='inside',
            insidetextanchor='middle',  # Anchor text to the middle of the bar
            textfont_color='white',
            textfont_size=10,  # Small enough to fit inside narrow bars
            hovertemplate='Date: %{x|%Y-%m-%d}<br>Retention: %{y:.2f}%',
        )

        # Dashed reference line at the mean retention of the plotted days.
        average_retention = data['percent_retained'].mean()
        fig.add_hline(
            y=average_retention,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Average: {average_retention:.2f}%",
            annotation_position="bottom right",
        )

        fig.update_xaxes(tickangle=45)
        fig.update_layout(
            title_x=0.5,  # Center title
            xaxis_title="Date",
            yaxis_title="Retention Rate (%)",
            plot_bgcolor='white',  # White background similar to seaborn whitegrid
            bargap=0.2,  # Gap between bars of different categories
        )
        return fig
    except Exception as e:
        print(f"Error during plot_retention_data: {e}", file=sys.stderr)
        return _message_figure(f"Plotting Error: {str(e)}")


def interface_fn(start_date, end_date):
    """Gradio callback: query retention for the date range and plot it."""
    result = get_retention_data(start_date, end_date)
    return plot_retention_data(result)


def _collection_date_bounds(frame: pd.DataFrame):
    """Return the (earliest, latest) calendar dates in ``frame['collected_at']``.

    ``pd.to_datetime`` accepts both ISO-8601 strings and already-parsed
    timestamps, whereas ``datetime.fromisoformat`` only accepts strings.
    Falls back to today's date for both bounds if no value can be parsed.
    """
    collected = pd.to_datetime(frame['collected_at'], errors='coerce').dropna()
    if collected.empty:
        today = datetime.now().date()
        return today, today
    return collected.min().date(), collected.max().date()


# Get min and max dates from the dataset to prefill the inputs.
min_date, max_date = _collection_date_bounds(df)

iface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Textbox(label="Start Date (YYYY-MM-DD)", value=min_date.strftime("%Y-%m-%d")),
        gr.Textbox(label="End Date (YYYY-MM-DD)", value=max_date.strftime("%Y-%m-%d")),
    ],
    outputs=gr.Plot(label="Model Retention Visualization"),
    title="Model Retention Analysis",
    description="Visualize model retention rates over time. Enter dates in YYYY-MM-DD format.",
)

iface.launch()