import gradio as gr
import pandas as pd
import numpy as np
import joblib
import os
# ==============================================================================
# 1. LOAD MODELS AND SCALER (This part runs once when the script starts)
# ==============================================================================
# Dictionary to hold the loaded model objects and a list of their names
all_models = {}
model_names = [
    'Linear Regression', 'Ridge Regression', 'Lasso Regression',
    'Random Forest', 'Gradient Boosting'
]
BEST_MODEL_NAME = 'Random Forest' # Define the best model to be highlighted
try:
    # Load all the regression models
    for name in model_names:
        # Construct the filename, e.g., 'models/random_forest.joblib'
        filename = f"models/{name.lower().replace(' ', '_')}.joblib"
        if os.path.exists(filename):
            all_models[name] = joblib.load(filename)
        else:
            raise FileNotFoundError(f"Model file not found: {filename}")

    # Load the scaler
    scaler_path = 'models/scaler.joblib'
    if os.path.exists(scaler_path):
        scaler = joblib.load(scaler_path)
    else:
        raise FileNotFoundError(f"Scaler file not found: {scaler_path}")

    models_loaded = True
    print("✅ All models and scaler loaded successfully!")

    # Get the feature names the model was trained on from the scaler
    expected_columns = scaler.feature_names_in_
    print(f"Models expect {len(expected_columns)} features.")
except Exception as e:
    print(f"❌ ERROR: Could not load models. {e}")
    print("Please ensure all '.joblib' files are in the 'models/' directory.")
    models_loaded = False
    all_models = {}
    scaler = None
    expected_columns = []
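# For reference, a minimal hypothetical sketch of how these artifacts might have
# been produced at training time (an assumption, not the actual training code;
# X_train and y_train are placeholder names):
#
#   from sklearn.preprocessing import StandardScaler
#   from sklearn.ensemble import RandomForestRegressor
#   scaler = StandardScaler().fit(X_train)   # fit() on a DataFrame records feature_names_in_
#   model = RandomForestRegressor().fit(scaler.transform(X_train), y_train)
#   joblib.dump(scaler, 'models/scaler.joblib')
#   joblib.dump(model, 'models/random_forest.joblib')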
# ==============================================================================
# 2. PREDICTION FUNCTION
# ==============================================================================
def predict_shares_all_models(likes, generation_time, gpu_usage, file_size_kb,
                              width, height, style_accuracy_score,
                              is_hand_edited, ethical_concerns_flag,
                              day_of_week, month, hour, platform):
    """
    Performs feature engineering, predicts shares using all loaded models,
    and returns formatted outputs for the Gradio interface.
    """
    if not models_loaded:
        error_message = "Models are not loaded. Please check the console for errors."
        return 0, error_message, error_message
    # --- Step A: Perform feature engineering ---
    sample_data = {
        'likes': likes,
        'style_accuracy_score': style_accuracy_score,
        'generation_time': generation_time,
        'gpu_usage': gpu_usage,
        'file_size_kb': file_size_kb,
        'is_hand_edited': int(is_hand_edited),
        'ethical_concerns_flag': int(ethical_concerns_flag),
        'width': width,
        'height': height,
        'day_of_week': day_of_week,
        'month': month,
        'hour': hour
    }

    # Derived geometry and timing features
    sample_data['aspect_ratio'] = width / height if height > 0 else 0
    sample_data['total_pixels'] = width * height
    sample_data['is_square'] = int(width == height)
    sample_data['is_weekend'] = int(day_of_week >= 5)

    # One-hot encode the platform
    for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:
        sample_data[f'platform_{p}'] = 1 if platform == p else 0

    # Interaction and ratio features
    sample_data['engagement_rate'] = likes / (sample_data['total_pixels'] / 1_000_000 + 1)
    sample_data['quality_engagement'] = style_accuracy_score * likes / 100
    sample_data['file_density'] = file_size_kb / (sample_data['total_pixels'] / 1000 + 1)
    sample_data['gpu_efficiency'] = generation_time / (gpu_usage + 1)

    # Platform-specific likes interactions
    for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:
        sample_data[f'{p.lower()}_likes'] = likes * sample_data[f'platform_{p}']

    # Cyclical encodings for month and day of week (see the note below)
    sample_data['month_sin'] = np.sin(2 * np.pi * month / 12)
    sample_data['month_cos'] = np.cos(2 * np.pi * month / 12)
    sample_data['day_sin'] = np.sin(2 * np.pi * day_of_week / 7)
    sample_data['day_cos'] = np.cos(2 * np.pi * day_of_week / 7)
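    # NOTE: a quick intuition for the cyclical encoding above (illustrative
    # values only). As raw integers, December (12) and January (1) look far
    # apart, but on the unit circle they are neighbours:
    #   month=12 -> (sin, cos) = (0.0, 1.0)
    #   month=1  -> (sin, cos) ~ (0.5, 0.87)
    # whereas month=6 lands at (0.0, -1.0), diametrically opposite.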
    # --- Step B: Align columns and scale ---
    sample_df = pd.DataFrame([sample_data])
    sample_df = sample_df.reindex(columns=expected_columns, fill_value=0)
    sample_scaled = scaler.transform(sample_df)
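    # NOTE: reindex() both orders the columns exactly as the scaler saw them at
    # fit time and fills any column we did not construct here with 0 (harmless
    # for one-hot dummies); columns not in expected_columns are silently
    # dropped. This keeps train-time and inference-time feature sets aligned.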
    # --- Step C: Predict with all models ---
    predictions = {}
    for name, model in all_models.items():
        pred_value = model.predict(sample_scaled)[0]
        predictions[name] = max(0, int(pred_value))  # share counts cannot be negative

    # --- Step D: Format the outputs for Gradio ---
    # 1. Get the single best model's prediction
    best_model_prediction = predictions.get(BEST_MODEL_NAME, 0)

    # 2. Create a Markdown table for all model predictions
    all_results_df = pd.DataFrame(list(predictions.items()), columns=['Model', 'Predicted Shares'])
    all_results_df = all_results_df.sort_values('Predicted Shares', ascending=False)
    all_models_table = all_results_df.to_markdown(index=False)

    # 3. Create a Markdown table for the engineered features
    features_df = sample_df.T.reset_index()
    features_df.columns = ['Feature', 'Value']
    features_df['Value'] = features_df['Value'].apply(lambda x: f"{x:.4f}" if isinstance(x, float) else x)
    features_table = features_df.to_markdown(index=False)

    return best_model_prediction, all_models_table, features_table
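# Example of invoking the predictor directly for a quick smoke test, bypassing
# the UI (illustrative values; requires the models to have loaded):
#
#   best, table, feats = predict_shares_all_models(
#       likes=500, generation_time=8.0, gpu_usage=70, file_size_kb=1500,
#       width=1024, height=1024, style_accuracy_score=85,
#       is_hand_edited=False, ethical_concerns_flag=False,
#       day_of_week=4, month=7, hour=18, platform='Instagram')
#   print(best, table, sep='\n')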
# ==============================================================================
# 3. GRADIO INTERFACE
# ==============================================================================
with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Virality Predictor") as demo:
    gr.Markdown("# 🎨 AI Ghibli Image Virality Predictor")
    gr.Markdown("Enter image features to get a virality prediction from multiple regression models.")

    with gr.Row():
        # --- INPUTS COLUMN ---
        with gr.Column(scale=2):
            gr.Markdown("### 1. Input Features")
            with gr.Accordion("Core Engagement & Image Metrics", open=True):
                likes = gr.Slider(minimum=0, maximum=10000, value=500, step=10, label="Likes")
                style_accuracy_score = gr.Slider(minimum=0, maximum=100, value=85, step=1, label="Style Accuracy Score (%)")
                width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Width (px)")
                height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Height (px)")
                file_size_kb = gr.Slider(minimum=100, maximum=5000, value=1500, step=100, label="File Size (KB)")

            with gr.Accordion("Technical & Posting Details", open=True):
                generation_time = gr.Slider(minimum=1, maximum=30, value=8, step=0.5, label="Generation Time (s)")
                gpu_usage = gr.Slider(minimum=10, maximum=100, value=70, step=5, label="GPU Usage (%)")
                platform = gr.Radio(["Instagram", "Twitter", "TikTok", "Reddit"], label="Platform", value="Instagram")
                day_of_week = gr.Slider(minimum=0, maximum=6, value=4, step=1, label="Day of Week (0=Mon, 6=Sun)")
                month = gr.Slider(minimum=1, maximum=12, value=7, step=1, label="Month (1-12)")
                hour = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Hour of Day (0-23)")
                is_hand_edited = gr.Checkbox(label="Was it Hand Edited?", value=False)
                ethical_concerns_flag = gr.Checkbox(label="Any Ethical Concerns?", value=False)

            predict_btn = gr.Button("Predict Virality", variant="primary")

        # --- OUTPUTS COLUMN ---
        with gr.Column(scale=3):
            gr.Markdown("### 2. Prediction Results")

            # Highlighted best-model output
            best_model_output = gr.Number(
                label=f"🏆 Best Model Prediction ({BEST_MODEL_NAME})",
                interactive=False
            )

            # Table for all model predictions
            with gr.Accordion("Comparison of All Models", open=True):
                all_models_output = gr.Markdown(label="All Model Predictions")

            # Table for feature engineering details
            with gr.Accordion("View Engineered Features", open=False):
                features_output = gr.Markdown(label="Feature Engineering Details")
    # Connect the button to the function. Gradio maps the inputs list
    # positionally to the function's parameters, so the order here must
    # match the signature of predict_shares_all_models.
    predict_btn.click(
        fn=predict_shares_all_models,
        inputs=[
            likes, generation_time, gpu_usage, file_size_kb,
            width, height, style_accuracy_score,
            is_hand_edited, ethical_concerns_flag,
            day_of_week, month, hour, platform
        ],
        outputs=[
            best_model_output,
            all_models_output,
            features_output
        ]
    )
# Launch the app
if __name__ == "__main__":
    if not models_loaded:
        print("\nCannot launch Gradio app because models failed to load.")
    else:
        demo.launch()
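        # NOTE: demo.launch(share=True) would also create a temporary public
        # link, and demo.launch(server_name="0.0.0.0") binds all network
        # interfaces (useful inside containers); the defaults are kept here.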