# Spaces:
# Sleeping
# Sleeping
import io
import json
import logging
import os
import tempfile
import warnings
from contextlib import redirect_stdout

import google.generativeai as genai
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import accuracy_score, confusion_matrix, r2_score, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
# --- Configuration ---
# Process-wide: silence every Python warning. Note this also hides deprecation
# warnings raised by dependencies.
warnings.filterwarnings('ignore')

# Root logger: INFO and above, timestamped lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Dark "glass" look shared by the whole Blocks app.
THEME = (
    gr.themes.Glass(primary_hue="blue", secondary_hue="cyan")
    .set(
        body_background_fill="rgba(0,0,0,0.8)",
        block_background_fill="rgba(0,0,0,0.6)",
        block_border_width="1px",
        border_color_primary="rgba(255,255,255,0.1)",
    )
)

# Problem type -> {model display name -> scikit-learn estimator class}.
# Looked up with the model name chosen in the Oracle tab.
MODEL_REGISTRY = dict(
    Classification={
        "Random Forest": RandomForestClassifier,
        "Logistic Regression": LogisticRegression,
    },
    Regression={
        "Random Forest": RandomForestRegressor,
        "Linear Regression": LinearRegression,
    },
)
# --- Core Logic ---
def safe_exec(code_string: str, local_vars: dict) -> tuple:
    """Execute a string of Python code and capture what it prints.

    NOTE(security): this is *not* a sandbox. ``exec`` runs the code with full
    interpreter privileges, and in this app the code comes from an LLM whose
    prompt contains user text. Only deploy where that is acceptable.

    The code runs in ONE namespace: a copy of this module's globals overlaid
    with ``local_vars``. Passing separate globals/locals dicts to ``exec``
    gives the code class-body semantics, so functions, lambdas and (on some
    Python versions) comprehensions it defines could not see ``df`` or names
    the code itself assigned. Using a copy also keeps the code from rebinding
    this module's globals.

    Args:
        code_string: Python source to execute.
        local_vars: Names exposed to the code (e.g. ``df``, ``px``). Names the
            code binds are written back into this dict afterwards.

    Returns:
        ``(stdout, fig, df_result, error)``. On success ``error`` is ``None``
        and ``fig`` / ``df_result`` are whatever the code bound to those names
        (``None`` if unbound). On failure ``fig`` and ``df_result`` are
        ``None``, ``error`` is ``"Execution Error: <message>"`` and ``stdout``
        holds any output printed before the failure, or ``None`` if there was none.
    """
    module_globals = globals()
    namespace = {**module_globals, **local_vars}
    output_buffer = io.StringIO()
    error = None
    try:
        with redirect_stdout(output_buffer):
            exec(code_string, namespace)
    except (Exception, SystemExit) as e:
        # SystemExit too: generated code calling exit()/sys.exit() must not stop the app.
        error = f"Execution Error: {str(e)}"

    # Keep the historical contract: names bound by the code become visible in local_vars.
    unset = object()
    for name, value in namespace.items():
        if name in local_vars or module_globals.get(name, unset) is not value:
            local_vars[name] = value

    stdout = output_buffer.getvalue()
    if error is not None:
        return stdout or None, None, None, error
    return stdout, local_vars.get('fig'), local_vars.get('df_result'), None
def _coerce_column_types(df):
    """Convert ``object`` columns of ``df`` in place and return ``df``.

    A column that ``pd.to_datetime`` parses without error becomes datetime;
    otherwise, if fewer than half of its rows hold distinct values, it becomes
    ``category``.
    """
    n_rows = max(len(df), 1)  # guards the ratio below against a header-only CSV
    for col in df.select_dtypes(include=['object']).columns:
        try:
            df[col] = pd.to_datetime(df[col], errors='raise')
        except (ValueError, TypeError, OverflowError):
            if df[col].nunique() / n_rows < 0.5:  # If not too many unique values
                df[col] = df[col].astype('category')
    return df


def _compute_insights(df, metadata):
    """Compute the Phoenix Eye insights dict for ``df``.

    Keys: ``missing`` (Series of NaN counts > 0, largest first),
    ``high_cardinality`` ({column: distinct count} for categorical columns with
    more than 50 values), ``high_correlations`` (top-5 absolute pairwise
    correlations; only present with 2+ numeric columns), ``outliers``
    ({column: count} via the 1.5*IQR rule) and ``ml_suggestions`` (list of str).
    """
    numeric_cols = metadata['numeric_cols']
    categorical_cols = metadata['categorical_cols']
    insights = {}

    # 1. Missing data
    missing = df.isnull().sum()
    insights['missing'] = missing[missing > 0].sort_values(ascending=False)

    # 2. High cardinality
    cardinality = {c: df[c].nunique() for c in categorical_cols}
    insights['high_cardinality'] = {c: n for c, n in cardinality.items() if n > 50}

    # 3. Top correlations. Keep only the upper triangle (k=1) so each pair appears
    #    once and self-correlations are excluded, while perfect correlations between
    #    two distinct columns (e.g. a duplicated column) are still reported.
    if len(numeric_cols) > 1:
        corr = df[numeric_cols].corr().abs()
        upper = np.triu(np.ones(corr.shape, dtype=bool), k=1)
        pairs = corr.where(upper).stack().dropna()
        insights['high_correlations'] = pairs.sort_values(ascending=False).head(5)

    # 4. Outlier detection (IQR method)
    outliers = {}
    for col in numeric_cols:
        q1, q3 = df[col].quantile(0.25), df[col].quantile(0.75)
        iqr = q3 - q1
        count = ((df[col] < (q1 - 1.5 * iqr)) | (df[col] > (q3 + 1.5 * iqr))).sum()
        if count > 0:
            outliers[col] = count
    insights['outliers'] = outliers

    # 5. ML target suggestions (> 20 distinct numeric values is treated as continuous)
    suggestions = [f"{c} (Binary Classification)" for c in categorical_cols if df[c].nunique() == 2]
    suggestions += [f"{c} (Regression)" for c in numeric_cols if df[c].nunique() > 20]
    insights['ml_suggestions'] = suggestions
    return insights


def prime_data(file_obj):
    """Load the uploaded CSV, compute proactive insights and prime every tab.

    Args:
        file_obj: Value delivered by the ``gr.File`` upload. Depending on the
            Gradio version this is a filepath string or a tempfile-like object
            exposing ``.name``; both are accepted.

    Returns:
        A dict mapping output components to values / ``gr.update`` objects
        (Gradio's dict-return style). On failure the tabs are revealed and the
        Phoenix Eye panel shows the error.
    """
    if not file_obj:
        # Keys must be components (the original keyed the dict by an update object).
        return {phoenix_tabs: gr.update(visible=False)}
    try:
        path = getattr(file_obj, 'name', file_obj)
        df = _coerce_column_types(pd.read_csv(path))
        metadata = extract_dataset_metadata(df)
        insights = _compute_insights(df, metadata)
        state = {
            'df_original': df,
            'df_modified': df.copy(),
            'filename': os.path.basename(path),
            'metadata': metadata,
            'proactive_insights': insights,
        }
        all_cols = metadata['columns']
        return {
            global_state: state,
            phoenix_tabs: gr.update(visible=True),
            phoenix_eye_output: generate_phoenix_eye_markdown(state),
            # Data Medic: only columns that actually have gaps can be cleaned.
            medic_col_select: gr.update(choices=insights['missing'].index.tolist(), interactive=True),
            # Oracle
            oracle_target_select: gr.update(choices=all_cols, interactive=True),
            oracle_feature_select: gr.update(choices=all_cols, interactive=True),
        }
    except Exception as e:
        logging.exception("Priming Error: %s", e)
        # The insights panel lives inside the (initially hidden) tabs; reveal them
        # or the error message is never seen.
        return {
            phoenix_tabs: gr.update(visible=True),
            phoenix_eye_output: gr.update(value=f"โ **Error:** {e}"),
        }
def extract_dataset_metadata(df):
    """Summarise a DataFrame's shape and column types.

    Returns:
        dict with
        - ``shape``: ``(rows, cols)``
        - ``columns``: all column labels, in order
        - ``numeric_cols``: ``np.number`` columns (bool is not included)
        - ``categorical_cols``: object, category and bool columns
        - ``datetime_cols``: naive and timezone-aware datetime columns
        - ``dtypes``: ``{column: dtype name}``
    """
    rows, cols = df.shape
    return {
        'shape': (rows, cols),
        'columns': df.columns.tolist(),
        'numeric_cols': df.select_dtypes(include=np.number).columns.tolist(),
        # bool columns were previously in no bucket, hiding binary targets like True/False flags.
        'categorical_cols': df.select_dtypes(include=['object', 'category', 'bool']).columns.tolist(),
        # 'datetimetz' catches tz-aware columns (e.g. parsed from strings with UTC offsets).
        'datetime_cols': df.select_dtypes(include=['datetime64', 'datetime64[ns]', 'datetimetz']).columns.tolist(),
        'dtypes': {col: dtype.name for col, dtype in df.dtypes.items()},
    }
def generate_phoenix_eye_markdown(state):
    """Render the Phoenix Eye insights dashboard as a Markdown string.

    Reads ``state['filename']``, ``state['metadata']['shape']`` (rows, cols)
    and ``state['proactive_insights']``: ``ml_suggestions`` (list),
    ``missing`` (Series), an optional ``high_correlations`` (Series),
    ``outliers`` and ``high_cardinality`` (dicts). Sections are emitted in a
    fixed order, each followed by a blank line; non-empty Series are rendered
    with ``to_frame(...).to_markdown()``.
    """
    found = state['proactive_insights']
    chunks = [
        f"## ๐ฆ Phoenix Eye: Proactive Insights for `{state['filename']}`\n",
        f"Dataset has **{state['metadata']['shape'][0]} rows** and **{state['metadata']['shape'][1]} columns**.\n\n",
    ]

    # Candidate prediction targets
    chunks.append("### ๐ฎ Potential ML Targets\n")
    targets = found['ml_suggestions']
    if not targets:
        chunks.append("No obvious ML target columns found.\n")
    else:
        chunks.extend(f"- `{target}`\n" for target in targets)
    chunks.append("\n")

    # Columns with gaps
    chunks.append("### ๐ง Missing Data\n")
    gaps = found['missing']
    if gaps.empty:
        chunks.append("โ No missing data found!\n")
    else:
        chunks.append("Found missing values in these columns. Use the **Data Medic** tab to fix.\n")
        chunks.append(gaps.to_frame('Missing Count').to_markdown() + "\n")
    chunks.append("\n")

    # Strongest numeric correlations (the key may be absent)
    chunks.append("### ๐ Top Correlations\n")
    has_corr = 'high_correlations' in found and not found['high_correlations'].empty
    chunks.append(
        found['high_correlations'].to_frame('Correlation').to_markdown() + "\n"
        if has_corr
        else "No strong correlations found between numeric features.\n"
    )
    chunks.append("\n")

    # Per-column outlier counts
    chunks.append("### ๐ Outlier Alert\n")
    flagged = found['outliers']
    if not flagged:
        chunks.append("โ No significant outliers detected.\n")
    else:
        chunks.extend(f"- `{name}` has **{count}** potential outliers.\n" for name, count in flagged.items())
    chunks.append("\n")

    # Categorical columns with many distinct values
    chunks.append("### ๐ High Cardinality Warning\n")
    wide = found['high_cardinality']
    if not wide:
        chunks.append("โ No high-cardinality categorical columns found.\n")
    else:
        chunks.extend(
            f"- `{name}` has **{count}** unique values, which may be problematic for some models.\n"
            for name, count in wide.items()
        )
    chunks.append("\n")

    return "".join(chunks)
# --- Tab Handlers ---
def medic_preview_imputation(state, col, method):
    """Plot ``col`` before and after a trial imputation, without saving anything.

    The preview uses ``state['df_modified']``, the same working copy that the
    Apply button edits, so it reflects fixes that were already applied.

    Args:
        state: App state dict (empty until a CSV has been uploaded).
        col: Column to preview; falsy means nothing is selected.
        method: 'mean' or 'median' (used for numeric/datetime columns); any
            other value, or a non-numeric column, fills with the most frequent
            value (mode).

    Returns:
        An overlaid before/after histogram figure, or ``None`` when there is
        nothing to preview.
    """
    if not col or not state:
        return None
    before = state['df_modified'][col]
    numeric_like = (pd.api.types.is_numeric_dtype(before)
                    or pd.api.types.is_datetime64_any_dtype(before))
    if method in ('mean', 'median') and numeric_like:
        fill_value = before.mean() if method == 'mean' else before.median()
    else:
        # mean/median are undefined for text/categorical data, so use the mode.
        modes = before.mode()
        fill_value = None if modes.empty else modes.iloc[0]
    # An all-missing column has no mode; show it unchanged instead of crashing.
    after = before if fill_value is None else before.fillna(fill_value)

    fig = go.Figure()
    fig.add_trace(go.Histogram(x=before, name='Before', opacity=0.7))
    fig.add_trace(go.Histogram(x=after, name='After', opacity=0.7))
    fig.update_layout(barmode='overlay', title=f"'{col}' Distribution: Before vs. After Imputation", legend_title_text='Dataset')
    return fig
def medic_apply_imputation(state, col, method):
    """Fill the missing values of ``col`` in the working DataFrame and save it.

    Args:
        state: App state dict; ``state['df_modified']`` is replaced with the
            filled copy and ``state['proactive_insights']['missing']`` is
            recomputed (largest count first).
        col: Column to fill; falsy means nothing is selected.
        method: 'mean' or 'median' (used for numeric/datetime columns); any
            other value, or a non-numeric column, fills with the mode.

    Returns:
        ``(state, status_message, dropdown_update)``: always three values, to
        match the three outputs wired to the Apply button. After a successful
        fill the dropdown lists only columns that still have missing values and
        its selection is cleared.
    """
    if not state:
        return state, "Please upload a CSV file first.", gr.update()
    if not col:
        return state, "No column selected.", gr.update()

    df_mod = state['df_modified'].copy()
    series = df_mod[col]
    numeric_like = (pd.api.types.is_numeric_dtype(series)
                    or pd.api.types.is_datetime64_any_dtype(series))
    use_statistic = method in ('mean', 'median') and numeric_like
    if use_statistic:
        fill_value = series.mean() if method == 'mean' else series.median()
    else:
        modes = series.mode()
        if modes.empty:
            return state, f"Column '{col}' has no non-missing values to impute from.", gr.update()
        fill_value = modes.iloc[0]

    df_mod[col] = series.fillna(fill_value)
    state['df_modified'] = df_mod

    # Refresh the missing-data insight for the modified data.
    missing = df_mod.isnull().sum()
    missing = missing[missing > 0].sort_values(ascending=False)
    state['proactive_insights']['missing'] = missing

    if method in ('mean', 'median') and not use_statistic:
        status = f"'{method}' needs a numeric column; applied 'mode' imputation to '{col}' instead."
    else:
        status = f"โ Applied '{method}' imputation to '{col}'."
    # value=None: the just-cleaned column is no longer a valid choice.
    return state, status, gr.update(choices=missing.index.tolist(), value=None)
def download_cleaned_data(state):
    """Write the cleaned DataFrame to a temporary CSV and reveal it for download.

    ``gr.File`` expects the path of an existing file, so the CSV is written to a
    fresh temporary directory as ``<original name>_cleaned.csv``. Passing the CSV
    text as the value, as before, does not produce a downloadable file.

    Returns:
        A ``gr.update`` for the download ``gr.File``: visible with the CSV path,
        or hidden when no data has been loaded yet.
    """
    df = state.get('df_modified') if state else None
    if df is None:
        return gr.update(value=None, visible=False)
    stem = os.path.splitext(state.get('filename') or 'data.csv')[0]
    out_path = os.path.join(tempfile.mkdtemp(prefix='phoenix_'), f"{stem}_cleaned.csv")
    df.to_csv(out_path, index=False)
    return gr.update(value=out_path, visible=True)
def oracle_run_model(state, target, features, model_name):
    """Train a quick model on the working data and return plots plus a report.

    Args:
        state: App state dict; training uses ``state['df_modified']``.
        target: Name of the column to predict.
        features: Input column names. The target is removed from this list if
            it was also selected, so it can never leak into the inputs.
        model_name: Display name looked up in ``MODEL_REGISTRY[problem_type]``.

    Rows with a missing value in any selected column are dropped. Columns that
    are neither numeric nor datetime are label-encoded (as strings). The task is
    classification when the target was label-encoded or has at most 10 distinct
    values, otherwise regression. A 70/30 split with ``random_state=42`` is used,
    and ``random_state=42`` is passed to the estimator when it accepts one.

    Returns:
        ``(importance_fig, result_fig, report_md)``: a feature-importance bar
        chart (``None`` for models without ``feature_importances_``), a confusion
        matrix or predicted-vs-actual scatter, and a Markdown report. On invalid
        input or a training failure both figures are ``None`` and the third value
        explains the problem.
    """
    if not state or 'df_modified' not in state:
        return None, None, "Please upload a CSV file first."
    features = [f for f in (features or []) if f != target]
    if not target or not features:
        return None, None, "Please select a target and at least one feature."
    if not model_name:
        return None, None, "Please select a model."

    # Preprocessing
    columns = features + [target]
    df = state['df_modified'][columns].dropna()
    if df.empty:
        return None, None, "Not enough data after dropping NA values."

    # str() first so columns with mixed value types can still be sorted/encoded.
    encoded = set()
    for col in columns:
        series = df[col]
        if not (pd.api.types.is_numeric_dtype(series) or pd.api.types.is_datetime64_any_dtype(series)):
            df[col] = LabelEncoder().fit_transform(series.astype(str))
            encoded.add(col)

    X, y = df[features], df[target]
    # A text/categorical target is classification even with more than 10 classes.
    problem_type = "Classification" if target in encoded or y.nunique() <= 10 else "Regression"
    model_cls = MODEL_REGISTRY[problem_type].get(model_name)
    if model_cls is None:
        return None, None, f"Model {model_name} not suitable for {problem_type}."
    model = model_cls()
    # Not every estimator accepts random_state (LinearRegression does not).
    if 'random_state' in model.get_params():
        model.set_params(random_state=42)

    try:
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
        model.fit(X_train, y_train)
        preds = model.predict(X_test)
    except (ValueError, TypeError) as e:
        # e.g. too few rows to split, datetime features, non-integral class labels.
        return None, None, f"Training failed: {e}"

    # Results
    importance_fig = None
    if hasattr(model, 'feature_importances_'):
        importance = pd.Series(model.feature_importances_, index=features).sort_values(ascending=False)
        importance_fig = px.bar(importance, title="Feature Importance")

    if problem_type == "Classification":
        acc = accuracy_score(y_test, preds)
        cm_fig = px.imshow(confusion_matrix(y_test, preds), text_auto=True,
                           title=f"Confusion Matrix (Accuracy: {acc:.2f})")
        return importance_fig, cm_fig, f"**Classification Report:**\n- Accuracy: {acc:.2f}"

    r2 = r2_score(y_test, preds)
    rmse = np.sqrt(mean_squared_error(y_test, preds))
    scatter_kwargs = dict(x=y_test, y=preds, labels={'x': 'Actual Values', 'y': 'Predicted Values'},
                          title=f"Predictions vs. Actuals (R²: {r2:.2f})")
    try:
        preds_fig = px.scatter(trendline='ols', **scatter_kwargs)
    except ImportError:
        # trendline='ols' needs the optional statsmodels package; plot without it.
        preds_fig = px.scatter(**scatter_kwargs)
    return importance_fig, preds_fig, f"**Regression Report:**\n- R² Score: {r2:.2f}\n- RMSE: {rmse:.2f}"
def copilot_respond(user_message, history, state, api_key):
    """Stream one AI Co-pilot turn: ask Gemini for code, run it, report results.

    This is a generator. Every ``yield`` is ``(history, figure, dataframe, code)``
    for the four outputs wired to it. ``history`` is a list of ``(user, bot)``
    tuples; it is copied, not mutated in place.

    NOTE(security): the model's code is executed by ``safe_exec``, which is not
    a sandbox, and the prompt embeds the user's message and the column names.
    """
    history = list(history or [])
    if not api_key:
        # In a generator, ``return value`` only sets StopIteration.value, which Gradio
        # ignores, so the reply has to be yielded.
        yield history + [(user_message, "I need a Gemini API key to function.")], None, None, ""
        return
    if not state or 'df_modified' not in state:
        yield history + [(user_message, "Please upload a CSV file first.")], None, None, ""
        return
    if not user_message or not str(user_message).strip():
        yield history, gr.update(), gr.update(), gr.update()
        return

    df = state['df_modified']
    # Describe the CURRENT data (after any Data Medic fixes), as the prompt promises.
    dtypes = {str(col): dtype.name for col, dtype in df.dtypes.items()}
    prompt = "\n".join([
        "You are 'Phoenix Co-pilot', a world-class AI data analyst. Your goal is to help the user by writing and executing Python code.",
        "You have access to a pandas DataFrame named `df`. This is the user's LATEST data, including any cleaning they've performed.",
        "",
        "**DataFrame Info:**",
        f"- Columns and dtypes: {json.dumps(dtypes)}",
        "",
        "**Instructions:**",
        f"1. Analyze the user's request: '{user_message}'.",
        "2. Formulate a plan (thought).",
        "3. Write Python code to execute the plan.",
        "4. Use `pandas`, `numpy`, and `plotly.express as px`.",
        "5. To show a plot, assign it to a variable `fig`. Ex: `fig = px.histogram(df, x='age')`.",
        "6. To show a dataframe, assign it to a variable `df_result`. Ex: `df_result = df.describe()`.",
        "7. Use `print()` for text output.",
        "8. **NEVER** modify `df` in place. Use `df.copy()` if needed.",
        '9. Respond **ONLY** with a single, valid JSON object with keys "thought" and "code".',
        "",
        f'**User Request:** "{user_message}"',
        "**Your JSON Response:**",
    ])

    history.append((user_message, None))
    raw_text = ""
    try:
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-1.5-flash')
        response = model.generate_content(prompt)
        raw_text = response.text
        # Strip optional Markdown code fences around the JSON before parsing.
        cleaned = raw_text.strip().replace("```json", "").replace("```", "").strip()
        response_json = json.loads(cleaned)
        thought = response_json.get("thought", "Thinking...")
        code_to_run = response_json.get("code", "print('No code generated.')")
        bot_thinking = f"๐ง **Thinking:** *{thought}*"
        history[-1] = (user_message, bot_thinking)
        yield history, None, None, gr.update(value=code_to_run)

        # Run on a copy so generated code cannot alter the app's working DataFrame.
        local_vars = {'df': df.copy(), 'px': px, 'pd': pd, 'np': np}
        stdout, fig_result, df_result, error = safe_exec(code_to_run, local_vars)
        if isinstance(df_result, pd.Series):
            # gr.Dataframe needs a table; results like value_counts() are Series.
            df_result = df_result.to_frame()

        sections = []
        if error:
            sections.append(f"๐ฅ **Execution Error:**\n```\n{error}\n```")
        if stdout:
            sections.append(f"๐ **Output:**\n```\n{stdout}\n```")
        if not sections and fig_result is None and not isinstance(df_result, pd.DataFrame):
            sections.append("โ Code executed, but produced no direct output.")
        history[-1] = (user_message, bot_thinking + "\n\n---\n\n" + "\n\n".join(sections))
        yield history, fig_result, df_result, gr.update(value=code_to_run)
    except Exception as e:
        logging.exception("Co-pilot error")
        error_msg = f"A critical error occurred: {e}. The AI may have returned invalid JSON. Check the generated code."
        history[-1] = (user_message, error_msg)
        # Show the raw model reply (if any) so "Check the generated code" is actionable.
        yield history, None, None, raw_text
# --- Gradio UI Construction ---
# Layout: an upload row, then a tab set that stays hidden until ``prime_data``
# reveals it after a CSV is uploaded.
with gr.Blocks(theme=THEME, title="Phoenix AI Data Explorer") as demo:
    # Per-session app state, filled by ``prime_data`` on upload.
    global_state = gr.State({})
    gr.Markdown("# ๐ฅ Phoenix AI Data Explorer")
    gr.Markdown("The next-generation analytic tool. Upload your data to awaken the Phoenix.")
    with gr.Row():
        file_input = gr.File(label="๐ Upload CSV", file_types=[".csv"])
        api_key_input = gr.Textbox(label="๐ Gemini API Key", type="password", placeholder="Enter Google AI Studio key...")

    with gr.Tabs(visible=False) as phoenix_tabs:
        with gr.Tab("๐ฆ Phoenix Eye"):
            phoenix_eye_output = gr.Markdown()

        with gr.Tab("๐ฉบ Data Medic"):
            gr.Markdown("### Cleanse Your Data\nSelect a column with missing values and choose a method to fill them.")
            with gr.Row():
                medic_col_select = gr.Dropdown(label="Select Column to Clean")
                medic_method_select = gr.Radio(['mean', 'median', 'mode'], label="Imputation Method", value='mean')
            medic_preview_btn = gr.Button("๐ Preview Changes")
            medic_plot = gr.Plot()
            with gr.Row():
                medic_apply_btn = gr.Button("โ Apply & Save Changes", variant="primary")
                medic_status = gr.Textbox(label="Status", interactive=False)
            with gr.Accordion("Download Cleaned Data", open=False):
                download_btn = gr.Button("โฌ๏ธ Download Cleaned CSV")
                download_file_output = gr.File(label="Download Link", visible=False)

        with gr.Tab("๐ฎ The Oracle (Predictive Modeling)"):
            gr.Markdown("### Glimpse the Future\nTrain a simple model to see the predictive power of your data.")
            with gr.Row():
                oracle_target_select = gr.Dropdown(label="๐ฏ Select Target Variable")
                # Gradio has no ``gr.Multiselect`` component; a multi-select Dropdown is the equivalent.
                oracle_feature_select = gr.Dropdown(label="โจ Select Features", multiselect=True)
                oracle_model_select = gr.Dropdown(choices=["Random Forest", "Logistic Regression", "Linear Regression"], label="๐ง Select Model")
            oracle_run_btn = gr.Button("๐ Train Model!", variant="primary")
            oracle_status = gr.Markdown()
            with gr.Row():
                oracle_fig1 = gr.Plot()
                oracle_fig2 = gr.Plot()

        with gr.Tab("๐ค AI Co-pilot"):
            gr.Markdown("### Your Conversational Analyst\nAsk any question about your data in plain English.")
            copilot_chatbot = gr.Chatbot(label="Chat History", height=400)
            with gr.Accordion("AI Generated Results", open=True):
                copilot_fig_output = gr.Plot()
                copilot_df_output = gr.Dataframe(interactive=False)
            with gr.Accordion("Generated Code", open=False):
                copilot_code_output = gr.Code(language="python", interactive=False)
            with gr.Row():
                copilot_input = gr.Textbox(label="Your Question", placeholder="e.g., 'What's the correlation between age and salary?'", scale=4)
                copilot_submit_btn = gr.Button("Submit", variant="primary", scale=1)

    # --- Event Wiring ---
    file_input.upload(
        fn=prime_data,
        inputs=file_input,
        outputs=[global_state, phoenix_tabs, phoenix_eye_output, medic_col_select, oracle_target_select, oracle_feature_select],
        show_progress="full",
    )

    # Data Medic
    medic_preview_btn.click(medic_preview_imputation, [global_state, medic_col_select, medic_method_select], medic_plot)
    medic_apply_btn.click(medic_apply_imputation, [global_state, medic_col_select, medic_method_select], [global_state, medic_status, medic_col_select])
    download_btn.click(download_cleaned_data, [global_state], download_file_output)

    # Oracle
    oracle_run_btn.click(
        oracle_run_model,
        [global_state, oracle_target_select, oracle_feature_select, oracle_model_select],
        [oracle_fig1, oracle_fig2, oracle_status],
        show_progress="full",
    )

    # AI Co-pilot: the Submit button and pressing Enter in the textbox both send the question.
    copilot_inputs = [copilot_input, copilot_chatbot, global_state, api_key_input]
    copilot_outputs = [copilot_chatbot, copilot_fig_output, copilot_df_output, copilot_code_output]
    for trigger in (copilot_submit_btn.click, copilot_input.submit):
        # Clear the input afterwards. The clearing step takes no inputs (None): wiring
        # the textbox as an input would call the zero-argument lambda with one argument.
        trigger(copilot_respond, copilot_inputs, copilot_outputs).then(lambda: "", None, copilot_input)

if __name__ == "__main__":
    # copilot_respond is a generator handler; Gradio 3.x requires the queue for those.
    demo.queue().launch(debug=True)