{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ZBiNdra-AOT2" }, "source": [ "# Import Library" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "taECrFNE9yxz" }, "outputs": [], "source": [ "# Import necessary libraries\n", "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.linear_model import LinearRegression, Ridge, Lasso\n", "from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor\n", "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "markdown", "metadata": { "id": "9MK_JVsjBjM-" }, "source": [ "# Import Dataset" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 417 }, "id": "XJVoE7SgBlDG", "outputId": "f60fab29-501d-4a88-b21b-a752d901790e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset shape: (500, 16)\n", "\n", "First 5 rows:\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
image_iduser_idpromptlikessharescommentsplatformgeneration_timegpu_usagefile_size_kbresolutionstyle_accuracy_scoreis_hand_editedethical_concerns_flagcreation_datetop_comment
077ce5c72-eb45-4651-bcb1-c0677c0fceaf6a7adf3dStudio Ghibli-inspired ocean with giant fish916410555Reddit4.804916841024x102489YesYes2025-03-11So nostalgic, feels like childhood memories. šŸŽ„...
17d66c67f-0d11-4ef9-895c-d865ef11fe40523b8706Ghibli-style village at sunset29651361417Reddit11.118128081024x102492YesNo2025-03-11Absolutely stunning! Love the details. šŸŽØ #5729
2d7978afd-3932-4cce-9a21-5f9bf2bc1f640e02592aA lone traveler exploring an enchanted ruin4727655785Instagram5.564118002048x204861NoNo2025-03-06Is this AI or hand-painted? Incredible! #8001
3cb34636a-a15c-4b15-999c-759dbb8896fe9ed78a42Spirited Away-style bustling market street16291954212TikTok12.45884792048x204876NoNo2025-03-23Is this AI or hand-painted? Incredible! #5620
47511fbb8-db05-4584-a3a4-e8bb525ed58b69ec8f02Magical Ghibli forest with floating lanterns25731281913TikTok4.80641789512x51258NoYes2025-03-06This looks straight out of a Ghibli movie! 🌟 #...
\n", "
" ], "text/plain": [ " image_id user_id \\\n", "0 77ce5c72-eb45-4651-bcb1-c0677c0fceaf 6a7adf3d \n", "1 7d66c67f-0d11-4ef9-895c-d865ef11fe40 523b8706 \n", "2 d7978afd-3932-4cce-9a21-5f9bf2bc1f64 0e02592a \n", "3 cb34636a-a15c-4b15-999c-759dbb8896fe 9ed78a42 \n", "4 7511fbb8-db05-4584-a3a4-e8bb525ed58b 69ec8f02 \n", "\n", " prompt likes shares comments \\\n", "0 Studio Ghibli-inspired ocean with giant fish 916 410 555 \n", "1 Ghibli-style village at sunset 2965 1361 417 \n", "2 A lone traveler exploring an enchanted ruin 4727 655 785 \n", "3 Spirited Away-style bustling market street 1629 1954 212 \n", "4 Magical Ghibli forest with floating lanterns 2573 1281 913 \n", "\n", " platform generation_time gpu_usage file_size_kb resolution \\\n", "0 Reddit 4.80 49 1684 1024x1024 \n", "1 Reddit 11.11 81 2808 1024x1024 \n", "2 Instagram 5.56 41 1800 2048x2048 \n", "3 TikTok 12.45 88 479 2048x2048 \n", "4 TikTok 4.80 64 1789 512x512 \n", "\n", " style_accuracy_score is_hand_edited ethical_concerns_flag creation_date \\\n", "0 89 Yes Yes 2025-03-11 \n", "1 92 Yes No 2025-03-11 \n", "2 61 No No 2025-03-06 \n", "3 76 No No 2025-03-23 \n", "4 58 No Yes 2025-03-06 \n", "\n", " top_comment \n", "0 So nostalgic, feels like childhood memories. šŸŽ„... \n", "1 Absolutely stunning! Love the details. šŸŽØ #5729 \n", "2 Is this AI or hand-painted? Incredible! #8001 \n", "3 Is this AI or hand-painted? Incredible! #5620 \n", "4 This looks straight out of a Ghibli movie! 🌟 #... 
" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the dataset\n", "df = pd.read_csv('dataset/ai_ghibli_trend_dataset_v2.csv')\n", "print(f\"Dataset shape: {df.shape}\")\n", "print(\"\\nFirst 5 rows:\")\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "qRbZogQMOBHN" }, "source": [ "# Data Preprocessing and Feature Engineering" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fRPjDRY2OC_T", "outputId": "137bc025-8391-43be-80fe-5ba76ebf6437" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============================================================\n", "FEATURE ENGINEERING\n", "============================================================\n", "Features created successfully!\n", "Total features: 25\n" ] } ], "source": [ "# Feature Engineering\n", "print(\"=\"*60)\n", "print(\"FEATURE ENGINEERING\")\n", "print(\"=\"*60)\n", "\n", "# Split resolution into width and height\n", "df[['width', 'height']] = df['resolution'].str.split('x', expand=True).astype(int)\n", "\n", "# Convert categorical binary features to numeric\n", "df['is_hand_edited'] = (df['is_hand_edited'] == 'Yes').astype(int)\n", "df['ethical_concerns_flag'] = (df['ethical_concerns_flag'] == 'Yes').astype(int)\n", "\n", "# Extract temporal features\n", "df['creation_date'] = pd.to_datetime(df['creation_date'])\n", "df['day_of_week'] = df['creation_date'].dt.dayofweek\n", "df['month'] = df['creation_date'].dt.month\n", "df['hour'] = df['creation_date'].dt.hour\n", "\n", "# Create derived features\n", "df['aspect_ratio'] = df['width'] / df['height']\n", "df['total_pixels'] = df['width'] * df['height']\n", "df['is_square'] = (df['width'] == df['height']).astype(int)\n", "df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)\n", "\n", "print(f\"Features created successfully!\")\n", "print(f\"Total features: {df.shape[1]}\")" ] }, { "cell_type": 
"markdown", "metadata": { "id": "mWdj1slFZYSG" }, "source": [ "# Target Variable Analysis" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 606 }, "id": "B-k5NDAAhL1J", "outputId": "30045678-8915-41b6-85c8-0e40cf847f35" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============================================================\n", "TARGET VARIABLE ANALYSIS\n", "============================================================\n", "\n", "Shares Statistics:\n", "Mean: 1040.18\n", "Median: 1092.00\n", "Std Dev: 562.67\n", "Min: 13\n", "Max: 1999\n", "Skewness: -0.14\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Analyze the target variable\n", "print(\"=\"*60)\n", "print(\"TARGET VARIABLE ANALYSIS\")\n", "print(\"=\"*60)\n", "\n", "# Define target and features\n", "y = df['shares']\n", "X = df.drop(columns=['image_id', 'user_id', 'prompt', 'shares', 'comments',\n", " 'top_comment', 'resolution', 'creation_date'])\n", "\n", "# One-hot encode platform\n", "X = pd.get_dummies(X, columns=['platform'], prefix='platform')\n", "\n", "# Target statistics\n", "print(f\"\\nShares Statistics:\")\n", "print(f\"Mean: {y.mean():.2f}\")\n", "print(f\"Median: {y.median():.2f}\")\n", "print(f\"Std Dev: {y.std():.2f}\")\n", "print(f\"Min: {y.min()}\")\n", "print(f\"Max: {y.max()}\")\n", "print(f\"Skewness: {y.skew():.2f}\")\n", "\n", "# Visualize target distribution\n", "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n", "\n", "ax1.hist(y, bins=50, edgecolor='black', alpha=0.7)\n", "ax1.set_xlabel('Shares')\n", "ax1.set_ylabel('Frequency')\n", "ax1.set_title('Distribution of Shares')\n", "ax1.grid(True, alpha=0.3)\n", "\n", "ax2.boxplot(y, vert=True)\n", "ax2.set_ylabel('Shares')\n", "ax2.set_title('Boxplot of Shares')\n", "ax2.grid(True, alpha=0.3)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "-G8d_ZuXZedt" }, "source": [ "# Feature Analysis and Selection" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "TRbt0BQthNXO", "outputId": "1aabd86c-c700-4a5a-a798-130a73a009a6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============================================================\n", "FEATURE CORRELATION ANALYSIS\n", "============================================================\n", "Top 15 Feature Correlations with Shares:\n", " feature correlation\n", "19 platform_Twitter -0.113105\n", "16 platform_Instagram 0.070970\n", "13 total_pixels 
0.053401\n", "7 width 0.050954\n", "8 height 0.050954\n", "17 platform_Reddit 0.030825\n", "0 likes -0.029318\n", "5 is_hand_edited 0.028240\n", "9 day_of_week 0.024903\n", "3 file_size_kb -0.020748\n", "6 ethical_concerns_flag 0.019647\n", "18 platform_TikTok 0.016843\n", "2 gpu_usage 0.015755\n", "15 is_weekend 0.015367\n", "4 style_accuracy_score 0.011279\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Feature correlation analysis\n", "print(\"=\"*60)\n", "print(\"FEATURE CORRELATION ANALYSIS\")\n", "print(\"=\"*60)\n", "\n", "# Calculate correlations with target\n", "correlations = pd.DataFrame({\n", " 'feature': X.columns,\n", " 'correlation': [X[col].corr(y) for col in X.columns]\n", "}).sort_values('correlation', key=abs, ascending=False)\n", "\n", "print(\"Top 15 Feature Correlations with Shares:\")\n", "print(correlations.head(15))\n", "\n", "# Visualization of correlations\n", "plt.figure(figsize=(10, 8))\n", "top_features = correlations.head(15)\n", "colors = ['green' if x > 0 else 'red' for x in top_features['correlation']]\n", "plt.barh(range(len(top_features)), top_features['correlation'], color=colors, alpha=0.7)\n", "plt.yticks(range(len(top_features)), top_features['feature'])\n", "plt.xlabel('Correlation with Shares')\n", "plt.title('Top 15 Feature Correlations')\n", "plt.axvline(x=0, color='black', linestyle='-', linewidth=0.5)\n", "plt.grid(True, alpha=0.3)\n", "plt.gca().invert_yaxis()\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "Gt1ouI4BZkjE" }, "source": [ "# Advanced Feature Engineering" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "A1KULUuphPON", "outputId": "29ec939b-7ac0-40b9-d88a-f1933a37f153" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============================================================\n", "CREATING ADVANCED FEATURES\n", "============================================================\n", "Total features after engineering: 31\n", "Features after removing multicollinearity: 29\n" ] } ], "source": [ "# Create interaction and polynomial features\n", "print(\"=\"*60)\n", "print(\"CREATING ADVANCED FEATURES\")\n", "print(\"=\"*60)\n", "\n", "# Log transform target (to handle skewness)\n", "y_log = np.log1p(y) # 
log1p handles zeros safely\n", "\n", "# Create interaction features\n", "X['engagement_rate'] = X['likes'] / (X['total_pixels'] / 1000000 + 1)\n", "X['quality_engagement'] = X['style_accuracy_score'] * X['likes'] / 100\n", "X['file_density'] = X['file_size_kb'] / (X['total_pixels'] / 1000 + 1)\n", "X['gpu_efficiency'] = X['generation_time'] / (X['gpu_usage'] + 1)\n", "\n", "# Platform-specific features\n", "for platform in ['Twitter', 'TikTok', 'Reddit']:\n", " if f'platform_{platform}' in X.columns:\n", " X[f'{platform.lower()}_likes'] = X['likes'] * X[f'platform_{platform}']\n", "\n", "# Temporal cyclical features\n", "X['month_sin'] = np.sin(2 * np.pi * X['month'] / 12)\n", "X['month_cos'] = np.cos(2 * np.pi * X['month'] / 12)\n", "X['day_sin'] = np.sin(2 * np.pi * X['day_of_week'] / 7)\n", "X['day_cos'] = np.cos(2 * np.pi * X['day_of_week'] / 7)\n", "\n", "print(f\"Total features after engineering: {X.shape[1]}\")\n", "\n", "# Remove highly correlated features\n", "corr_matrix = X.corr().abs()\n", "upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))\n", "to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > 0.95)]\n", "X = X.drop(columns=to_drop)\n", "print(f\"Features after removing multicollinearity: {X.shape[1]}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "x3AkuY7LBlTP" }, "source": [ "# Train-Test Split and Scaling" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7o5a1IBIBoYr", "outputId": "5c5cbae8-6c2a-4b60-ca5a-b420b1a175c2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============================================================\n", "DATA SPLITTING AND SCALING\n", "============================================================\n", "Training set: (400, 29)\n", "Test set: (100, 29)\n", "Data preprocessing completed!\n" ] } ], "source": [ "# Split the data\n", "print(\"=\"*60)\n", 
"print(\"DATA SPLITTING AND SCALING\")\n", "print(\"=\"*60)\n", "\n", "# Use both original and log-transformed targets\n", "X_train, X_test, y_train, y_test, y_log_train, y_log_test = train_test_split(\n", " X, y, y_log, test_size=0.2, random_state=42\n", ")\n", "\n", "print(f\"Training set: {X_train.shape}\")\n", "print(f\"Test set: {X_test.shape}\")\n", "\n", "# Scale features\n", "scaler = StandardScaler()\n", "X_train_scaled = scaler.fit_transform(X_train)\n", "X_test_scaled = scaler.transform(X_test)\n", "\n", "print(\"Data preprocessing completed!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "a9p0hQpBBosr" }, "source": [ "# Training Multiple Regression Models" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "YdVJjloPhWrK", "outputId": "64a29ce2-52ea-4049-9966-c46555f76ae9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============================================================\n", "TRAINING REGRESSION MODELS\n", "============================================================\n", "\n", "1. Linear Regression...\n", "2. Ridge Regression...\n", "3. Lasso Regression...\n", "4. Random Forest Regressor...\n", "5. Gradient Boosting Regressor...\n", "\n", "All models trained successfully!\n" ] } ], "source": [ "# Train multiple regression models\n", "print(\"=\"*60)\n", "print(\"TRAINING REGRESSION MODELS\")\n", "print(\"=\"*60)\n", "\n", "# Dictionary to store results\n", "results = {}\n", "\n", "# Model 1: Linear Regression\n", "print(\"\\n1. Linear Regression...\")\n", "lr = LinearRegression()\n", "lr.fit(X_train_scaled, y_train)\n", "y_pred_lr = lr.predict(X_test_scaled)\n", "results['Linear Regression'] = {\n", " 'predictions': y_pred_lr,\n", " 'r2': r2_score(y_test, y_pred_lr),\n", " 'mae': mean_absolute_error(y_test, y_pred_lr),\n", " 'rmse': np.sqrt(mean_squared_error(y_test, y_pred_lr))\n", "}\n", "\n", "# Model 2: Ridge Regression\n", "print(\"2. 
Ridge Regression...\")\n", "ridge = Ridge(alpha=10.0)\n", "ridge.fit(X_train_scaled, y_train)\n", "y_pred_ridge = ridge.predict(X_test_scaled)\n", "results['Ridge Regression'] = {\n", " 'predictions': y_pred_ridge,\n", " 'r2': r2_score(y_test, y_pred_ridge),\n", " 'mae': mean_absolute_error(y_test, y_pred_ridge),\n", " 'rmse': np.sqrt(mean_squared_error(y_test, y_pred_ridge))\n", "}\n", "\n", "# Model 3: Lasso Regression\n", "print(\"3. Lasso Regression...\")\n", "lasso = Lasso(alpha=1.0)\n", "lasso.fit(X_train_scaled, y_train)\n", "y_pred_lasso = lasso.predict(X_test_scaled)\n", "results['Lasso Regression'] = {\n", " 'predictions': y_pred_lasso,\n", " 'r2': r2_score(y_test, y_pred_lasso),\n", " 'mae': mean_absolute_error(y_test, y_pred_lasso),\n", " 'rmse': np.sqrt(mean_squared_error(y_test, y_pred_lasso))\n", "}\n", "\n", "# Model 4: Random Forest\n", "print(\"4. Random Forest Regressor...\")\n", "rf = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42)\n", "rf.fit(X_train_scaled, y_train)\n", "y_pred_rf = rf.predict(X_test_scaled)\n", "results['Random Forest'] = {\n", " 'predictions': y_pred_rf,\n", " 'r2': r2_score(y_test, y_pred_rf),\n", " 'mae': mean_absolute_error(y_test, y_pred_rf),\n", " 'rmse': np.sqrt(mean_squared_error(y_test, y_pred_rf))\n", "}\n", "\n", "# Model 5: Gradient Boosting\n", "print(\"5. 
Gradient Boosting Regressor...\")\n", "gb = GradientBoostingRegressor(n_estimators=100, max_depth=5, random_state=42)\n", "gb.fit(X_train_scaled, y_train)\n", "y_pred_gb = gb.predict(X_test_scaled)\n", "results['Gradient Boosting'] = {\n", " 'predictions': y_pred_gb,\n", " 'r2': r2_score(y_test, y_pred_gb),\n", " 'mae': mean_absolute_error(y_test, y_pred_gb),\n", " 'rmse': np.sqrt(mean_squared_error(y_test, y_pred_gb))\n", "}\n", "\n", "print(\"\\nAll models trained successfully!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "1GgeDPqbZyH0" }, "source": [ "# Model Comparison Table" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yEScyc2GiX86", "outputId": "777ce395-1d4b-4304-bb25-b01ebc6a90c5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============================================================\n", "MODEL PERFORMANCE COMPARISON\n", "============================================================\n", "\n", "Regression Models Performance:\n", " Model R² Score MAE RMSE\n", " Random Forest -0.085036 518.768500 593.750611\n", " Ridge Regression -0.086767 528.589291 594.223994\n", " Lasso Regression -0.087646 528.573249 594.464161\n", "Linear Regression -0.099142 531.413278 597.597560\n", "Gradient Boosting -0.235788 537.630957 633.656677\n", "\n", "Best performing model: Random Forest\n" ] } ], "source": [ "# Create comparison table\n", "print(\"=\"*60)\n", "print(\"MODEL PERFORMANCE COMPARISON\")\n", "print(\"=\"*60)\n", "\n", "comparison_df = pd.DataFrame({\n", " 'Model': results.keys(),\n", " 'R² Score': [results[m]['r2'] for m in results],\n", " 'MAE': [results[m]['mae'] for m in results],\n", " 'RMSE': [results[m]['rmse'] for m in results]\n", "})\n", "\n", "comparison_df = comparison_df.sort_values('R² Score', ascending=False)\n", "print(\"\\nRegression Models Performance:\")\n", "print(comparison_df.to_string(index=False))\n", "\n", "# Find best 
model\n", "best_model_name = comparison_df.iloc[0]['Model']\n", "print(f\"\\nBest performing model: {best_model_name}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "DWJr9Ff1Z2Uc" }, "source": [ "# Visualization of Results" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 807 }, "id": "fd_dYczSiwO3", "outputId": "cbe5047f-3d69-474b-c991-5b77ed6a3a26" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Comprehensive visualization\n", "fig, axes = plt.subplots(1, 3, figsize=(18, 8))\n", "\n", "# 1. Model Performance Comparison\n", "ax = axes[0]\n", "x_pos = np.arange(len(comparison_df))\n", "colors = plt.cm.viridis(np.linspace(0, 1, len(comparison_df)))\n", "bars = ax.bar(x_pos, comparison_df['R² Score'], color=colors, alpha=0.8)\n", "ax.set_xlabel('Models')\n", "ax.set_ylabel('R² Score')\n", "ax.set_title('Model R² Score Comparison')\n", "ax.set_xticks(x_pos)\n", "ax.set_xticklabels(comparison_df['Model'], rotation=45, ha='right')\n", "ax.grid(True, alpha=0.3)\n", "ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)\n", "\n", "# Add value labels\n", "for bar, value in zip(bars, comparison_df['R² Score']):\n", " ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,\n", " f'{value:.3f}', ha='center', va='bottom')\n", "\n", "# 2. MAE Comparison\n", "ax = axes[1]\n", "bars = ax.bar(x_pos, comparison_df['MAE'], color=colors, alpha=0.8)\n", "ax.set_xlabel('Models')\n", "ax.set_ylabel('MAE')\n", "ax.set_title('Mean Absolute Error by Model')\n", "ax.set_xticks(x_pos)\n", "ax.set_xticklabels(comparison_df['Model'], rotation=45, ha='right')\n", "ax.grid(True, alpha=0.3)\n", "\n", "# 3. 
RMSE Comparison (continued)\n", "ax = axes[2]\n", "bars = ax.bar(x_pos, comparison_df['RMSE'], color=colors, alpha=0.8)\n", "ax.set_xlabel('Models')\n", "ax.set_ylabel('RMSE')\n", "ax.set_title('Root Mean Square Error by Model')\n", "ax.set_xticks(x_pos)\n", "ax.set_xticklabels(comparison_df['Model'], rotation=45, ha='right')\n", "ax.grid(True, alpha=0.3)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "QNCcnPHs5WX6" }, "source": [ "# Best Model: Random Forest" ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 718 }, "id": "FU9AabI1vUJF", "outputId": "93da3ca3-03cf-40aa-9f47-d4997f15ae8d" }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Comprehensive visualization\n", "fig, axes = plt.subplots(1, 3, figsize=(25, 8))\n", "\n", "# 4. Best Model: Actual vs Predicted\n", "ax = axes[0]\n", "best_predictions = results[best_model_name]['predictions']\n", "ax.scatter(y_test, best_predictions, alpha=0.5, s=30)\n", "ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)\n", "ax.set_xlabel('Actual Shares')\n", "ax.set_ylabel('Predicted Shares')\n", "ax.set_title(f'{best_model_name}: Actual vs Predicted')\n", "ax.grid(True, alpha=0.3)\n", "\n", "# 5. Residual Plot for Best Model\n", "ax = axes[1]\n", "residuals = y_test - best_predictions\n", "ax.scatter(best_predictions, residuals, alpha=0.5, s=30)\n", "ax.axhline(y=0, color='red', linestyle='--')\n", "ax.set_xlabel('Predicted Values')\n", "ax.set_ylabel('Residuals')\n", "ax.set_title(f'{best_model_name}: Residual Plot')\n", "ax.grid(True, alpha=0.3)\n", "\n", "# 6. Feature Importance (for tree-based models)\n", "ax = axes[2]\n", "if best_model_name in ['Random Forest', 'Gradient Boosting']:\n", " model = rf if best_model_name == 'Random Forest' else gb\n", " feature_importance = pd.DataFrame({\n", " 'feature': X.columns,\n", " 'importance': model.feature_importances_\n", " }).sort_values('importance', ascending=False).head(10)\n", "\n", " y_pos = np.arange(len(feature_importance))\n", " ax.barh(y_pos, feature_importance['importance'], alpha=0.8)\n", " ax.set_yticks(y_pos)\n", " ax.set_yticklabels(feature_importance['feature'])\n", " ax.set_xlabel('Importance')\n", " ax.set_title('Top 10 Feature Importances')\n", " ax.invert_yaxis()\n", "else:\n", " # For linear models, show coefficients\n", " model = lr if best_model_name == 'Linear Regression' else (ridge if best_model_name == 'Ridge Regression' else lasso)\n", " coef_df = pd.DataFrame({\n", " 'feature': X.columns,\n", " 'coefficient': model.coef_\n", " }).sort_values('coefficient', key=abs, 
ascending=False).head(10)\n", "\n", " y_pos = np.arange(len(coef_df))\n", " colors_coef = ['green' if x > 0 else 'red' for x in coef_df['coefficient']]\n", " ax.barh(y_pos, coef_df['coefficient'], color=colors_coef, alpha=0.7)\n", " ax.set_yticks(y_pos)\n", " ax.set_yticklabels(coef_df['feature'])\n", " ax.set_xlabel('Coefficient Value')\n", " ax.set_title('Top 10 Coefficients by Magnitude')\n", " ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)\n", " ax.invert_yaxis()" ] }, { "cell_type": "markdown", "metadata": { "id": "1TEJk9craU8n" }, "source": [ "# Final Model Insights" ] }, { "cell_type": "markdown", "metadata": { "id": "dX0RUq7KagJw" }, "source": [ "KEY FINDING: all five models scored a negative R² on the held-out test set, i.e. they predicted shares less accurately than simply predicting the mean number of shares.\n", "\n", "POSSIBLE REASONS FOR MODEL PERFORMANCE:\n", "1. Low feature correlations suggest missing important predictors: no feature has an absolute correlation with shares above roughly 0.12\n", "2. Unmodelled non-linearity alone is unlikely to explain the poor fit, because the non-linear models (Random Forest, Gradient Boosting) did no better than the linear ones\n", "3. External factors not captured in the dataset may drive virality\n", "4. Possible data quality issues or synthetic data patterns: shares look roughly uniform between 13 and 1999 (mean 1040 vs. median 1092, skewness -0.14)\n", "\n", "RECOMMENDATIONS FOR IMPROVEMENT:\n", "1. Collect additional features (user follower count, time of day of posting, hashtags); creation_date appears to carry no time component, so the derived hour feature is constant\n", "2. Try more advanced models (XGBoost, Neural Networks), although with this little signal, better features matter more than model capacity\n", "3. Feature engineering focusing on user engagement patterns\n", "4. Consider time-series aspects of virality\n", "5.
Investigate outliers and data quality issues" ] }, { "cell_type": "markdown", "metadata": { "id": "Lf-wrybgaish" }, "source": [ "# Save Models" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ax8GI5zaalkF", "outputId": "7ee95963-c3ae-4149-acae-471962674149" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "All models and scaler have been saved successfully!\n" ] } ], "source": [ "import joblib\n", "\n", "# Save all models and the scaler\n", "models_to_save = {\n", " 'linear_regression': lr,\n", " 'ridge_regression': ridge,\n", " 'lasso_regression': lasso,\n", " 'random_forest': rf,\n", " 'gradient_boosting': gb\n", "}\n", "\n", "# Save each model\n", "for model_name, model in models_to_save.items():\n", " joblib.dump(model, f'models/{model_name}.joblib')\n", "\n", "# Save the scaler\n", "joblib.dump(scaler, 'models/scaler.joblib')\n", "\n", "print(\"All models and scaler have been saved successfully!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "IxqiL3sz28by" }, "source": [ "# Prediction Function" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "TIy5twgB48tl", "outputId": "3b6794da-579b-4f2a-ddfe-591b2c4be9d3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded: Linear Regression\n", "Loaded: Ridge Regression\n", "Loaded: Lasso Regression\n", "Loaded: Random Forest\n", "Loaded: Gradient Boosting\n", "Loaded: scaler.joblib\n", "\n", "āœ… All models and scaler loaded successfully!\n", "Model expects 29 features.\n" ] } ], "source": [ "# Dictionary to hold the loaded model objects\n", "all_models = {}\n", "model_names = [\n", " 'Linear Regression', 'Ridge Regression', 'Lasso Regression',\n", " 'Random Forest', 'Gradient Boosting'\n", "]\n", "\n", "try:\n", " # Load all the regression models\n", " for name in model_names:\n", " filename = 
f\"models/{name.lower().replace(' ', '_')}.joblib\"\n", " all_models[name] = joblib.load(filename)\n", " print(f\"Loaded: {name}\")\n", "\n", " # Load the scaler ONCE, after the loop\n", " scaler = joblib.load('models/scaler.joblib')\n", " print(\"Loaded: scaler.joblib\")\n", "\n", " models_loaded = True\n", " print(\"\\nāœ… All models and scaler loaded successfully!\")\n", "\n", " # Get the feature names the model was trained on from the scaler\n", " expected_columns = scaler.feature_names_in_\n", " print(f\"Model expects {len(expected_columns)} features.\")\n", "\n", "except FileNotFoundError as e:\n", " print(f\"\\nāŒ ERROR: Could not find a model file: {e}\")\n", " print(\"Please make sure all '.joblib' files are in the 'models/' directory.\")\n", " models_loaded = False" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "id": "jHDEssBjarnA" }, "outputs": [], "source": [ "def predict_shares_all_models(likes, generation_time, gpu_usage, file_size_kb,\n", " width, height, style_accuracy_score,\n", " is_hand_edited, ethical_concerns_flag,\n", " day_of_week, month, hour, platform):\n", "\n", " # --- 1. Create a dictionary with the input data ---\n", " sample_data = {\n", " 'likes': likes,\n", " 'style_accuracy_score': style_accuracy_score,\n", " 'generation_time': generation_time,\n", " 'gpu_usage': gpu_usage,\n", " 'file_size_kb': file_size_kb,\n", " 'is_hand_edited': int(is_hand_edited),\n", " 'ethical_concerns_flag': int(ethical_concerns_flag),\n", " 'width': width,\n", " 'height': height,\n", " 'day_of_week': day_of_week,\n", " 'month': month,\n", " 'hour': hour\n", " }\n", "\n", " # --- 2. 
Perform the same feature engineering as in training ---\n", " # Basic derived features\n", " sample_data['aspect_ratio'] = width / height if height > 0 else 0\n", " sample_data['total_pixels'] = width * height\n", " sample_data['is_square'] = int(width == height)\n", " sample_data['is_weekend'] = int(day_of_week >= 5)\n", "\n", " # One-hot encode platform\n", " for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:\n", " sample_data[f'platform_{p}'] = 1 if platform == p else 0\n", "\n", " # Advanced interaction features\n", " sample_data['engagement_rate'] = likes / (sample_data['total_pixels'] / 1000000 + 1)\n", " sample_data['quality_engagement'] = style_accuracy_score * likes / 100\n", " sample_data['file_density'] = file_size_kb / (sample_data['total_pixels'] / 1000 + 1)\n", " sample_data['gpu_efficiency'] = generation_time / (gpu_usage + 1)\n", "\n", " # Platform-specific likes\n", " for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:\n", " sample_data[f'{p.lower()}_likes'] = likes * sample_data[f'platform_{p}']\n", "\n", " # Temporal cyclical features\n", " sample_data['month_sin'] = np.sin(2 * np.pi * month / 12)\n", " sample_data['month_cos'] = np.cos(2 * np.pi * month / 12)\n", " sample_data['day_sin'] = np.sin(2 * np.pi * day_of_week / 7)\n", " sample_data['day_cos'] = np.cos(2 * np.pi * day_of_week / 7)\n", "\n", " # --- 3. Align columns with the training data ---\n", " # Create a DataFrame and ensure it has the exact same columns in the same order as the training data\n", " sample_df = pd.DataFrame([sample_data])\n", " sample_df = sample_df.reindex(columns=expected_columns, fill_value=0)\n", "\n", " # --- 4. Scale the features ---\n", " try:\n", " sample_scaled = scaler.transform(sample_df)\n", " except Exception as e:\n", " return 0, f\"Error during scaling: {e}\"\n", "\n", " # --- 5. 
Predict with all models and format output ---\n", " predictions = {}\n", " for name, model in all_models.items():\n", " pred_value = model.predict(sample_scaled)[0]\n", " predictions[name] = max(0, int(pred_value)) # Ensure non-negative and integer\n", "\n", " return predictions" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3qft9Mfu415x", "outputId": "7c6925a7-ce54-4dd4-abb0-fe7a5f395d4e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "==================================================\n", " RUNNING TEST WITH SAMPLE INPUT\n", "==================================================\n", "\n", "--- Input Values ---\n", " likes: 850\n", " generation_time: 7.5\n", " gpu_usage: 88\n", " file_size_kb: 1200\n", " width: 1080\n", " height: 1350\n", " style_accuracy_score: 92\n", " is_hand_edited: True\n", " ethical_concerns_flag: False\n", " day_of_week: 4\n", " month: 6\n", " hour: 19\n", " platform: Instagram\n", "\n", "--- Model Predictions ---\n", " Model Predicted Shares\n", "0 Gradient Boosting 1314\n", "1 Random Forest 1232\n", "2 Lasso Regression 1180\n", "3 Linear Regression 1175\n", "4 Ridge Regression 1169\n", "\n", "āœ… Test complete!\n" ] } ], "source": [ "# Only proceed if the models were loaded correctly\n", "if models_loaded:\n", " print(\"\\n\" + \"=\"*50)\n", " print(\" RUNNING TEST WITH SAMPLE INPUT\")\n", " print(\"=\"*50)\n", "\n", " # 1. Define a dictionary with sample values for a hypothetical image\n", " test_input = {\n", " \"likes\": 850,\n", " \"generation_time\": 7.5,\n", " \"gpu_usage\": 88,\n", " \"file_size_kb\": 1200,\n", " \"width\": 1080,\n", " \"height\": 1350, # Portrait aspect ratio\n", " \"style_accuracy_score\": 92,\n", " \"is_hand_edited\": True,\n", " \"ethical_concerns_flag\": False,\n", " \"day_of_week\": 4, # Friday\n", " \"month\": 6, # June\n", " \"hour\": 19, # 7 PM\n", " \"platform\": \"Instagram\"\n", " }\n", "\n", " # 2. 
Call your function using the test input dictionary\n", " # The ** operator unpacks the dictionary into keyword arguments\n", " all_predictions = predict_shares_all_models(**test_input)\n", "\n", " # 3. Display the results in a clean, readable format\n", " print(\"\\n--- Input Values ---\")\n", " for key, value in test_input.items():\n", " print(f\"{key:>25}: {value}\")\n", "\n", " print(\"\\n--- Model Predictions ---\")\n", " # Convert the results to a pandas DataFrame for nice printing\n", " results_df = pd.DataFrame(list(all_predictions.items()), columns=['Model', 'Predicted Shares'])\n", " results_df = results_df.sort_values('Predicted Shares', ascending=False).reset_index(drop=True)\n", "\n", " print(results_df.to_string())\n", "\n", " print(\"\\nāœ… Test complete!\")" ] }, { "cell_type": "markdown", "metadata": { "id": "ONkgTSKAbOZu" }, "source": [ "# Export Results" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Pehur_uDbPsF", "outputId": "e91160db-89b1-4b42-de22-415da079c50c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Results exported successfully!\n", "\n", "Files created:\n", "- regression_analysis_results.json\n", "- model_comparison.csv\n" ] } ], "source": [ "# Save all results for presentation\n", "import json\n", "\n", "# Prepare results dictionary\n", "presentation_results = {\n", " 'dataset_info': {\n", " 'total_samples': len(df),\n", " 'features': X.shape[1],\n", " 'target_mean': float(y.mean()),\n", " 'target_median': float(y.median()),\n", " 'target_std': float(y.std())\n", " },\n", " 'model_comparison': comparison_df.to_dict('records'),\n", " 'feature_correlations': correlations.head(10).to_dict('records')\n", "}\n", "\n", "# Save to JSON\n", "with open('results/regression_analysis_results.json', 'w') as f:\n", " json.dump(presentation_results, f, indent=2)\n", "\n", "# Save comparison table as CSV\n", 
"comparison_df.to_csv('results/model_comparison.csv', index=False)\n", "\n", "print(\"Results exported successfully!\")\n", "print(\"\\nFiles created:\")\n", "print(\"- regression_analysis_results.json\")\n", "print(\"- model_comparison.csv\")" ] }, { "cell_type": "markdown", "metadata": { "id": "gS9BH1XudOXW" }, "source": [ "# Gradio Demo App" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 629 }, "id": "xVXYa7sNaxF8", "outputId": "5b726664-dbd2-4206-ead7-782886664901" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "āœ… All models and scaler loaded successfully!\n", "Models expect 29 features.\n", "* Running on local URL: http://127.0.0.1:7860\n", "* Running on public URL: https://c56406accffdd945a3.gradio.live\n", "\n", "This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Traceback (most recent call last):\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/pandas/compat/_optional.py\", line 135, in import_optional_dependency\n", " module = importlib.import_module(name)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/usr/lib/python3.12/importlib/__init__.py\", line 90, in import_module\n", " return _bootstrap._gcd_import(name[level:], package, level)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"\", line 1387, in _gcd_import\n", " File \"\", line 1360, in _find_and_load\n", " File \"\", line 1324, in _find_and_load_unlocked\n", "ModuleNotFoundError: No module named 'tabulate'\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/queueing.py\", line 625, in process_events\n", " response = await route_utils.call_process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/blocks.py\", line 2218, in process_api\n", " result = await self.call_function(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group 
Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/blocks.py\", line 1729, in call_function\n", " prediction = await anyio.to_thread.run_sync( # type: ignore\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n", " return await get_async_backend().run_sync_in_worker_thread(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2470, in run_sync_in_worker_thread\n", " return await future\n", " ^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 967, in run\n", " result = context.run(func, *args)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/utils.py\", line 894, in wrapper\n", " response = f(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^\n", " File \"/tmp/ipykernel_70924/2141717853.py\", line 123, in predict_shares_all_models\n", " all_models_table = all_results_df.to_markdown(index=False)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/pandas/util/_decorators.py\", line 333, in wrapper\n", " return func(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/pandas/core/frame.py\", line 2988, in 
to_markdown\n", " tabulate = import_optional_dependency(\"tabulate\")\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/pandas/compat/_optional.py\", line 138, in import_optional_dependency\n", " raise ImportError(msg)\n", "ImportError: Missing optional dependency 'tabulate'. Use pip or conda to install tabulate.\n", "Traceback (most recent call last):\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/pandas/compat/_optional.py\", line 135, in import_optional_dependency\n", " module = importlib.import_module(name)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/usr/lib/python3.12/importlib/__init__.py\", line 90, in import_module\n", " return _bootstrap._gcd_import(name[level:], package, level)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"\", line 1387, in _gcd_import\n", " File \"\", line 1360, in _find_and_load\n", " File \"\", line 1324, in _find_and_load_unlocked\n", "ModuleNotFoundError: No module named 'tabulate'\n", "\n", "During handling of the above exception, another exception occurred:\n", "\n", "Traceback (most recent call last):\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/queueing.py\", line 625, in process_events\n", " response = await route_utils.call_process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/route_utils.py\", line 322, in call_process_api\n", " output = await app.get_blocks().process_api(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group 
Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/blocks.py\", line 2218, in process_api\n", " result = await self.call_function(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/blocks.py\", line 1729, in call_function\n", " prediction = await anyio.to_thread.run_sync( # type: ignore\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/anyio/to_thread.py\", line 56, in run_sync\n", " return await get_async_backend().run_sync_in_worker_thread(\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 2470, in run_sync_in_worker_thread\n", " return await future\n", " ^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/anyio/_backends/_asyncio.py\", line 967, in run\n", " result = context.run(func, *args)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/gradio/utils.py\", line 894, in wrapper\n", " response = f(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^\n", " File \"/tmp/ipykernel_70924/2141717853.py\", line 123, in predict_shares_all_models\n", " all_models_table = all_results_df.to_markdown(index=False)\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/pandas/util/_decorators.py\", line 333, 
in wrapper\n", " return func(*args, **kwargs)\n", " ^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/pandas/core/frame.py\", line 2988, in to_markdown\n", " tabulate = import_optional_dependency(\"tabulate\")\n", " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " File \"/home/ssyok/Documents/UM/GFW0003 Data Analytics /Group Assignment/AI-Ghibli-Image-Virality-Predictor/venv/lib/python3.12/site-packages/pandas/compat/_optional.py\", line 138, in import_optional_dependency\n", " raise ImportError(msg)\n", "ImportError: Missing optional dependency 'tabulate'. Use pip or conda to install tabulate.\n" ] } ], "source": [ "import gradio as gr\n", "import pandas as pd\n", "import numpy as np\n", "import joblib\n", "import os\n", "\n", "# ==============================================================================\n", "# 1. LOAD MODELS AND SCALER (This part runs once when the script starts)\n", "# ==============================================================================\n", "\n", "# Dictionary to hold the loaded model objects and a list of their names\n", "all_models = {}\n", "model_names = [\n", " 'Linear Regression', 'Ridge Regression', 'Lasso Regression',\n", " 'Random Forest', 'Gradient Boosting'\n", "]\n", "BEST_MODEL_NAME = 'Random Forest' # Define the best model to be highlighted\n", "\n", "try:\n", " # Load all the regression models\n", " for name in model_names:\n", " # Construct the filename, e.g., 'models/random_forest.joblib'\n", " filename = f\"models/{name.lower().replace(' ', '_')}.joblib\"\n", " if os.path.exists(filename):\n", " all_models[name] = joblib.load(filename)\n", " else:\n", " raise FileNotFoundError(f\"Model file not found: {filename}\")\n", "\n", " # Load the scaler\n", " scaler_path = 'models/scaler.joblib'\n", " if os.path.exists(scaler_path):\n", " scaler = joblib.load(scaler_path)\n", " else:\n", " raise 
FileNotFoundError(f\"Scaler file not found: {scaler_path}\")\n", "\n", " models_loaded = True\n", " print(\"āœ… All models and scaler loaded successfully!\")\n", "\n", " # Get the feature names the model was trained on from the scaler\n", " expected_columns = scaler.feature_names_in_\n", " print(f\"Models expect {len(expected_columns)} features.\")\n", "\n", "except Exception as e:\n", " print(f\"āŒ ERROR: Could not load models. {e}\")\n", " print(\"Please ensure all '.joblib' files are in the 'models/' directory.\")\n", " models_loaded = False\n", " all_models = {}\n", " scaler = None\n", " expected_columns = []\n", "\n", "# ==============================================================================\n", "# 2. PREDICTION FUNCTION\n", "# ==============================================================================\n", "\n", "def predict_shares_all_models(likes, generation_time, gpu_usage, file_size_kb,\n", " width, height, style_accuracy_score,\n", " is_hand_edited, ethical_concerns_flag,\n", " day_of_week, month, hour, platform):\n", " \"\"\"\n", " Performs feature engineering, predicts shares using all loaded models,\n", " and returns formatted outputs for the Gradio interface.\n", " \"\"\"\n", " if not models_loaded:\n", " error_message = \"Models are not loaded. 
Please check the console for errors.\"\n", " return 0, error_message, error_message\n", "\n", " # --- Step A: Perform feature engineering ---\n", " sample_data = {\n", " 'likes': likes,\n", " 'style_accuracy_score': style_accuracy_score,\n", " 'generation_time': generation_time,\n", " 'gpu_usage': gpu_usage,\n", " 'file_size_kb': file_size_kb,\n", " 'is_hand_edited': int(is_hand_edited),\n", " 'ethical_concerns_flag': int(ethical_concerns_flag),\n", " 'width': width,\n", " 'height': height,\n", " 'day_of_week': day_of_week,\n", " 'month': month,\n", " 'hour': hour\n", " }\n", "\n", " sample_data['aspect_ratio'] = width / height if height > 0 else 0\n", " sample_data['total_pixels'] = width * height\n", " sample_data['is_square'] = int(width == height)\n", " sample_data['is_weekend'] = int(day_of_week >= 5)\n", "\n", " for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:\n", " sample_data[f'platform_{p}'] = 1 if platform == p else 0\n", "\n", " sample_data['engagement_rate'] = likes / (sample_data['total_pixels'] / 1000000 + 1)\n", " sample_data['quality_engagement'] = style_accuracy_score * likes / 100\n", " sample_data['file_density'] = file_size_kb / (sample_data['total_pixels'] / 1000 + 1)\n", " sample_data['gpu_efficiency'] = generation_time / (gpu_usage + 1)\n", "\n", " for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:\n", " sample_data[f'{p.lower()}_likes'] = likes * sample_data[f'platform_{p}']\n", "\n", " sample_data['month_sin'] = np.sin(2 * np.pi * month / 12)\n", " sample_data['month_cos'] = np.cos(2 * np.pi * month / 12)\n", " sample_data['day_sin'] = np.sin(2 * np.pi * day_of_week / 7)\n", " sample_data['day_cos'] = np.cos(2 * np.pi * day_of_week / 7)\n", "\n", " # --- Step B: Align columns and Scale ---\n", " sample_df = pd.DataFrame([sample_data])\n", " sample_df = sample_df.reindex(columns=expected_columns, fill_value=0)\n", " sample_scaled = scaler.transform(sample_df)\n", "\n", " # --- Step C: Predict with all models ---\n", " predictions = 
{}\n", " for name, model in all_models.items():\n", " pred_value = model.predict(sample_scaled)[0]\n", " predictions[name] = max(0, int(pred_value))\n", "\n", " # --- Step D: Format the outputs for Gradio ---\n", "\n", " # 1. Get the single best model prediction\n", " best_model_prediction = predictions.get(BEST_MODEL_NAME, 0)\n", "\n", " # 2. Create a Markdown table for all model predictions\n", " all_results_df = pd.DataFrame(list(predictions.items()), columns=['Model', 'Predicted Shares'])\n", " all_results_df = all_results_df.sort_values('Predicted Shares', ascending=False)\n", " all_models_table = all_results_df.to_markdown(index=False)\n", "\n", " # 3. Create a Markdown table for the engineered features\n", " features_df = sample_df.T.reset_index()\n", " features_df.columns = ['Feature', 'Value']\n", " features_df['Value'] = features_df['Value'].apply(lambda x: f\"{x:.4f}\" if isinstance(x, float) else x)\n", " features_table = features_df.to_markdown(index=False)\n", "\n", " return best_model_prediction, all_models_table, features_table\n", "\n", "# ==============================================================================\n", "# 3. GRADIO INTERFACE\n", "# ==============================================================================\n", "\n", "with gr.Blocks(theme=gr.themes.Soft(), title=\"AI Image Virality Predictor\") as demo:\n", " gr.Markdown(\"# šŸŽØ AI Ghibli Image Virality Predictor\")\n", " gr.Markdown(\"Enter image features to get a virality prediction from multiple regression models.\")\n", "\n", " with gr.Row():\n", " # --- INPUTS COLUMN ---\n", " with gr.Column(scale=2):\n", " gr.Markdown(\"### 1. 
Input Features\")\n", " with gr.Accordion(\"Core Engagement & Image Metrics\", open=True):\n", " likes = gr.Slider(minimum=0, maximum=10000, value=500, step=10, label=\"Likes\")\n", " style_accuracy_score = gr.Slider(minimum=0, maximum=100, value=85, step=1, label=\"Style Accuracy Score (%)\")\n", " width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label=\"Width (px)\")\n", " height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label=\"Height (px)\")\n", " file_size_kb = gr.Slider(minimum=100, maximum=5000, value=1500, step=100, label=\"File Size (KB)\")\n", "\n", " with gr.Accordion(\"Technical & Posting Details\", open=True):\n", " generation_time = gr.Slider(minimum=1, maximum=30, value=8, step=0.5, label=\"Generation Time (s)\")\n", " gpu_usage = gr.Slider(minimum=10, maximum=100, value=70, step=5, label=\"GPU Usage (%)\")\n", " platform = gr.Radio([\"Instagram\", \"Twitter\", \"TikTok\", \"Reddit\"], label=\"Platform\", value=\"Instagram\")\n", " day_of_week = gr.Slider(minimum=0, maximum=6, value=4, step=1, label=\"Day of Week (0=Mon, 6=Sun)\")\n", " month = gr.Slider(minimum=1, maximum=12, value=7, step=1, label=\"Month (1-12)\")\n", " hour = gr.Slider(minimum=0, maximum=23, value=18, step=1, label=\"Hour of Day (0-23)\")\n", " is_hand_edited = gr.Checkbox(label=\"Was it Hand Edited?\", value=False)\n", " ethical_concerns_flag = gr.Checkbox(label=\"Any Ethical Concerns?\", value=False)\n", "\n", " predict_btn = gr.Button(\"Predict Virality\", variant=\"primary\")\n", "\n", " # --- OUTPUTS COLUMN ---\n", " with gr.Column(scale=3):\n", " gr.Markdown(\"### 2. 
Prediction Results\")\n", "\n", " # Highlighted Best Model Output\n", " best_model_output = gr.Number(\n", " label=f\"šŸ† Best Model Prediction ({BEST_MODEL_NAME})\",\n", " interactive=False\n", " )\n", "\n", " # Table for All Model Predictions\n", " with gr.Accordion(\"Comparison of All Models\", open=True):\n", " all_models_output = gr.Markdown(label=\"All Model Predictions\")\n", "\n", " # Table for Feature Engineering Details\n", " with gr.Accordion(\"View Engineered Features\", open=False):\n", " features_output = gr.Markdown(label=\"Feature Engineering Details\")\n", "\n", " # Connect the button to the function\n", " predict_btn.click(\n", " fn=predict_shares_all_models,\n", " inputs=[\n", " likes, generation_time, gpu_usage, file_size_kb,\n", " width, height, style_accuracy_score,\n", " is_hand_edited, ethical_concerns_flag,\n", " day_of_week, month, hour, platform\n", " ],\n", " outputs=[\n", " best_model_output,\n", " all_models_output,\n", " features_output\n", " ]\n", " )\n", "\n", "# Launch the app\n", "if __name__ == \"__main__\":\n", " if not models_loaded:\n", " print(\"\\nCannot launch Gradio app because models failed to load.\")\n", " else:\n", " demo.launch(share=True)" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }