Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*-
#
# PROJECT: CognitiveEDA - The Adaptive Intelligence Engine
#
# DESCRIPTION: A data discovery platform that provides a complete suite of
# standard EDA tools and automatically unlocks specialized analysis
# modules for Time-Series, Text, and Clustering when the uploaded data
# contains suitable columns, giving a comprehensive, context-aware
# analytical experience.
#
# SETUP: $ pip install -r requirements.txt
#
# AUTHOR: An MCP Expert in Data & AI Solutions
# VERSION: 4.2 (Bugfix Edition: AI Narrative Engine Restored)
# LAST-UPDATE: 2023-10-29 (Fixed a critical bug where the AI model was never called)
from __future__ import annotations | |
import warnings | |
import logging | |
import os | |
from datetime import datetime | |
from typing import Any, Dict, Optional, Tuple | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import plotly.express as px | |
import plotly.graph_objects as go | |
import google.generativeai as genai | |
# --- Local Adaptive Modules (Requires analysis_modules.py and requirements.txt) --- | |
from analysis_modules import analyze_time_series, generate_word_cloud, perform_clustering | |
# --- Configuration & Setup ---
# Suppress every FutureWarning so deprecation chatter does not flood the logs.
warnings.filterwarnings(action='ignore', category=FutureWarning)
# Root logger: INFO level, with timestamp, level and source location on each line.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - [%(levelname)s] - (%(filename)s:%(lineno)d) - %(message)s',
)
class Config:
    """Application-wide constants for the UI and the AI narrative engine."""

    # Shown as the browser-tab title and as the page's <h1> header.
    APP_TITLE: str = "π CognitiveEDA: The Adaptive Intelligence Engine"
    # Model identifier handed to genai.GenerativeModel for the AI report.
    GEMINI_MODEL: str = 'gemini-1.5-flash-latest'
    # Uploaded datasets with more rows than this are randomly down-sampled.
    MAX_UI_ROWS: int = 50_000
# --- Core Analysis Engine ---
class DataAnalyzer:
    """Profile a pandas DataFrame: metadata, summary tables, figures and an AI narrative.

    Metadata is computed lazily on the first access of ``metadata`` and cached
    for the lifetime of the instance.
    """

    # Absolute Pearson correlation above which a numeric column pair is
    # reported in ``metadata['high_corr_pairs']``.
    HIGH_CORR_THRESHOLD = 0.75
    # Mean string length (characters) above which an object/category column
    # is classified as long-form text.
    TEXT_MIN_MEAN_LENGTH = 50

    def __init__(self, df: pd.DataFrame):
        """Wrap *df* for analysis.

        Raises:
            TypeError: if *df* is not a ``pandas.DataFrame``.
        """
        if not isinstance(df, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame.")
        self.df = df
        self._metadata: Optional[Dict[str, Any]] = None
        logging.info(f"DataAnalyzer instantiated with DataFrame of shape: {self.df.shape}")

    @property
    def metadata(self) -> Dict[str, Any]:
        """Return the cached metadata dict, computing it on first access.

        Accessed as an attribute (``analyzer.metadata['numeric_cols']``), which
        is why this must be a property rather than a plain method.
        """
        if self._metadata is None:
            self._metadata = self._extract_metadata()
        return self._metadata

    def _is_text_column(self, col: Any) -> bool:
        """Return True if *col* holds strings whose mean length exceeds TEXT_MIN_MEAN_LENGTH.

        Object/category columns that do not hold strings (e.g. booleans or
        Decimals stored as ``object``) make the ``.str`` accessor raise; such
        columns are simply not text.
        """
        values = self.df[col].dropna()
        try:
            mean_length = values.str.len().mean()
        except (AttributeError, TypeError):
            return False
        # An all-missing column yields NaN, and NaN > threshold is False.
        return bool(mean_length > self.TEXT_MIN_MEAN_LENGTH)

    def _extract_metadata(self) -> Dict[str, Any]:
        """Compute shape, column-type buckets, missing-value stats and high-correlation pairs."""
        rows, cols = self.df.shape
        numeric_cols = self.df.select_dtypes(include=np.number).columns.tolist()
        categorical_cols = self.df.select_dtypes(include=['object', 'category']).columns.tolist()
        datetime_cols = self.df.select_dtypes(include=['datetime64', 'datetimetz']).columns.tolist()
        text_cols = [col for col in categorical_cols if self._is_text_column(col)]

        high_corr_pairs = []
        if len(numeric_cols) > 1:
            corr_matrix = self.df[numeric_cols].corr().abs()
            # Keep only the strict upper triangle: each pair once, no diagonal.
            upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
            high_corr_series = upper_tri.stack()
            strong = high_corr_series[high_corr_series > self.HIGH_CORR_THRESHOLD]
            high_corr_pairs = [
                {'Feature 1': first, 'Feature 2': second, 'Correlation': float(corr)}
                for (first, second), corr in strong.items()
            ]

        total_cells = self.df.size
        # Percentage of non-missing cells; an empty frame has nothing to score.
        if total_cells:
            quality_score = round((self.df.notna().sum().sum() / total_cells) * 100, 2)
        else:
            quality_score = 0.0

        return {
            'shape': (rows, cols), 'columns': self.df.columns.tolist(),
            'numeric_cols': numeric_cols, 'categorical_cols': categorical_cols,
            'datetime_cols': datetime_cols, 'text_cols': text_cols,
            'memory_usage_mb': f"{self.df.memory_usage(deep=True).sum() / 1e6:.2f}",
            'total_missing': int(self.df.isnull().sum().sum()),
            'data_quality_score': quality_score,
            'high_corr_pairs': high_corr_pairs,
        }

    def get_profiling_tables(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Return (missing-value table, numeric summary, categorical summary).

        Every table has a leading ``Column`` column. A summary is an empty
        frame (with only ``Column``) when the dataset has no columns of that
        kind, because ``DataFrame.describe`` raises on a frame without columns.
        """
        meta = self.metadata
        missing = self.df.isnull().sum()
        missing_df = (
            pd.DataFrame({
                'Missing Count': missing,
                'Missing Percentage (%)': (missing / len(self.df) * 100).round(2),
            })
            .rename_axis('Column')
            .reset_index()
            .sort_values('Missing Count', ascending=False)
        )

        if meta['numeric_cols']:
            numeric_stats = self.df[meta['numeric_cols']].describe(percentiles=[.01, .25, .5, .75, .99]).T
            numeric_stats_df = numeric_stats.round(3).rename_axis('Column').reset_index()
        else:
            numeric_stats_df = pd.DataFrame(columns=['Column'])

        if meta['categorical_cols']:
            cat_stats = self.df[meta['categorical_cols']].describe(include=['object', 'category']).T
            cat_stats_df = cat_stats.rename_axis('Column').reset_index()
        else:
            cat_stats_df = pd.DataFrame(columns=['Column'])

        return missing_df, numeric_stats_df, cat_stats_df

    def get_overview_visuals(self) -> Tuple[go.Figure, go.Figure, go.Figure]:
        """Return (dtype pie chart, missing-values bar chart, correlation heatmap).

        The heatmap is an empty figure when fewer than two numeric columns exist.
        """
        meta = self.metadata
        dtype_counts = self.df.dtypes.astype(str).value_counts()
        fig_types = px.pie(values=dtype_counts.values, names=dtype_counts.index, title="<b>π Data Type Composition</b>", hole=0.4, color_discrete_sequence=px.colors.qualitative.Pastel)
        # Name the axis explicitly so the column is 'index' even if df.columns is named.
        missing_df = self.df.isnull().sum().rename_axis('index').reset_index(name='count').query('count > 0')
        fig_missing = px.bar(missing_df, x='index', y='count', title="<b>π³οΈ Missing Values Distribution</b>", labels={'index': 'Column Name', 'count': 'Number of Missing Values'}).update_xaxes(categoryorder="total descending")
        fig_corr = go.Figure()
        if len(meta['numeric_cols']) > 1:
            corr_matrix = self.df[meta['numeric_cols']].corr()
            fig_corr = px.imshow(corr_matrix, text_auto=".2f", aspect="auto", title="<b>π Correlation Matrix</b>", color_continuous_scale='RdBu_r', zmin=-1, zmax=1)
        return fig_types, fig_missing, fig_corr

    def generate_ai_narrative(self, api_key: str, context: Dict[str, Any]) -> str:
        """Generate a context-aware Markdown report with the Gemini API.

        Args:
            api_key: Google Gemini API key.
            context: Flags steering the prompt; reads ``is_timeseries`` and ``has_text``.

        Returns:
            The model's Markdown text, or a Markdown error message if the
            request was blocked or failed (API errors are logged, not raised).
        """
        logging.info(f"Generating AI narrative with context: {list(context.keys())}")
        meta = self.metadata
        try:
            data_snippet_md = self.df.head(5).to_markdown(index=False)
        except ImportError:
            # DataFrame.to_markdown needs the optional 'tabulate' package; fall
            # back to plain text instead of aborting the whole analysis.
            data_snippet_md = self.df.head(5).to_string(index=False)

        context_prompt = "**PRIMARY ANALYSIS MODES:**\n"
        if context.get('is_timeseries'):
            context_prompt += "- **Time-Series Detected:** Focus on trends, seasonality, and stationarity. Suggest forecasting models.\n"
        if context.get('has_text'):
            context_prompt += "- **Long-Form Text Detected:** Note potential for NLP tasks like sentiment analysis or topic modeling.\n"
        if not context.get('is_timeseries') and not context.get('has_text'):
            context_prompt += "- **General Tabular Data:** Focus on distributions, correlations, and potential for classification/regression modeling.\n"

        prompt = f"""
As "Cognitive Analyst," an elite AI data scientist, your task is to generate a comprehensive, multi-part data discovery report in Markdown format.
{context_prompt}
**DATASET METADATA:**
- **Shape:** {meta['shape'][0]} rows, {meta['shape'][1]} columns.
- **Data Quality Score:** {meta['data_quality_score']}%
- **Total Missing Values:** {meta['total_missing']:,}
- **Highly Correlated Pairs:** {meta['high_corr_pairs'] if meta['high_corr_pairs'] else 'None detected.'}
- **Data Snippet (First 5 Rows):**
{data_snippet_md}
**REQUIRED REPORT STRUCTURE:**
# π AI Data Discovery Report
## π 1. Executive Summary
* **Primary Objective:** (Deduce the likely purpose of this dataset. What problem could it solve?)
* **Key Finding:** (State the single most interesting insight you've discovered.)
* **Overall State:** (Briefly comment on the data's quality and readiness for analysis.)
## π§ 2. Deep Dive & Quality Assessment
* **Structural Profile:** (Describe the dataset's composition: numeric, categorical, text, time-series features.)
* **Data Quality Audit:** (Elaborate on the quality score and missing values. Are they a major concern?)
* **Redundancy Check:** (Comment on the detected high-correlation pairs and any risks.)
## π‘ 3. Actionable Recommendations
* **Data Cleaning:** (Provide a specific recommendation for handling missing data or outliers.)
* **Feature Engineering:** (Suggest creating a new, valuable feature.)
* **Next Analytical Steps:** (Propose a specific hypothesis to test or a suitable ML model to build.)
"""
        try:
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel(Config.GEMINI_MODEL)
            response = model.generate_content(prompt)
            # No parts means the response was blocked; response.text would raise.
            if not response.parts:
                blocked_reason = response.prompt_feedback.block_reason.name if response.prompt_feedback else "Unknown"
                logging.warning(f"AI response blocked. Reason: {blocked_reason}")
                return f"β **AI Report Generation Blocked by Safety Settings**\n**Reason:** `{blocked_reason}`."
            return response.text
        except Exception as e:
            # Best-effort: surface the failure in the report instead of failing the dashboard.
            logging.error(f"Gemini API call failed: {e}", exc_info=True)
            return f"β **AI Report Generation Failed**\n**Error:** `{str(e)}`"
# --- UI Creation ---
def create_ui():
    """Build and return the complete Gradio Blocks interface (without launching it).

    All event handlers receive the ``DataAnalyzer`` held in ``gr.State``. That
    state is empty until "Build My Dashboard" has succeeded, and dropdowns can
    be cleared, so every handler returns an empty result instead of raising.
    """

    def create_histogram(analyzer: DataAnalyzer, col: str) -> go.Figure:
        """Histogram with a marginal box plot of *col*; empty figure if nothing is selected."""
        if not col or not analyzer:
            return go.Figure()
        return px.histogram(analyzer.df, x=col, title=f"<b>Distribution of {col}</b>", marginal="box", template="plotly_white")

    def create_scatterplot(analyzer: DataAnalyzer, x_col: str, y_col: str, color_col: str) -> go.Figure:
        """Scatter plot of *x_col* vs *y_col*, optionally coloured by *color_col*."""
        if not all([analyzer, x_col, y_col]):
            return go.Figure()
        return px.scatter(analyzer.df, x=x_col, y=y_col, color=color_col, title=f"<b>Scatter Plot: {x_col} vs. {y_col}</b>", template="plotly_white")

    def analyze_single_column(analyzer: DataAnalyzer, col: str) -> Tuple[str, go.Figure]:
        """Markdown stats plus a histogram (numeric) or top-10 bar chart (other) for *col*."""
        if not col or not analyzer:
            return "", go.Figure()
        series = analyzer.df[col]
        stats_md = f"### π **Deep Dive: `{col}`**\n- **Data Type:** `{series.dtype}`\n- **Unique Values:** `{series.nunique()}`\n- **Missing:** `{series.isnull().sum()}` ({series.isnull().mean():.2%})\n"
        if pd.api.types.is_numeric_dtype(series):
            stats_md += f"- **Mean:** `{series.mean():.3f}` | **Median:** `{series.median():.3f}` | **Std Dev:** `{series.std():.3f}`"
            fig = create_histogram(analyzer, col)
        else:
            counts = series.value_counts()
            if counts.empty:
                # Every value is missing: there is no top value to report or plot.
                stats_md += "- **Top Value:** `N/A`"
                fig = go.Figure()
            else:
                stats_md += f"- **Top Value:** `{counts.index[0]}`"
                top_n = counts.nlargest(10)
                fig = px.bar(top_n, y=top_n.index, x=top_n.values, orientation='h', title=f"<b>Top 10 Categories in `{col}`</b>").update_yaxes(categoryorder="total ascending")
        return stats_md, fig

    def run_time_series(analyzer: DataAnalyzer, date_col: str, value_col: str):
        """Delegate to analyze_time_series once both columns are chosen."""
        if not analyzer or not date_col or not value_col:
            return go.Figure(), ""
        return analyze_time_series(analyzer.df, date_col, value_col)

    def render_word_cloud(analyzer: DataAnalyzer, text_col: str):
        """Delegate to generate_word_cloud once a text column is chosen."""
        if not analyzer or not text_col:
            return ""
        return generate_word_cloud(analyzer.df, text_col)

    def run_clustering(analyzer: DataAnalyzer, k):
        """Delegate to perform_clustering on all numeric columns with *k* clusters."""
        if not analyzer:
            return go.Figure(), ""
        return perform_clustering(analyzer.df, analyzer.metadata['numeric_cols'], k)

    with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue"), title=Config.APP_TITLE) as demo:
        state_analyzer = gr.State()
        gr.Markdown(f"<h1>{Config.APP_TITLE}</h1>")
        gr.Markdown("Upload your data to receive a complete standard analysis, plus specialized dashboards that unlock automatically based on your data's content.")
        with gr.Row():
            upload_button = gr.File(label="1. Upload Data File (CSV, Excel)", file_types=[".csv", ".xlsx", ".xls"], scale=3)
            api_key_input = gr.Textbox(label="2. Enter Google Gemini API Key", type="password", scale=2)
            analyze_button = gr.Button("β¨ Build My Dashboard", variant="primary", scale=1)

        with gr.Tabs():
            with gr.Tab("π€ AI Narrative"):
                ai_report_output = gr.Markdown("### Your AI-generated report will appear here...")
            with gr.Tab("π Profile"):
                profile_missing_df, profile_numeric_df, profile_categorical_df = gr.DataFrame(), gr.DataFrame(), gr.DataFrame()
            with gr.Tab("π Overview Visuals"):
                with gr.Row():
                    plot_types, plot_missing = gr.Plot(), gr.Plot()
                plot_correlation = gr.Plot()
            with gr.Tab("π¨ Interactive Explorer"):
                with gr.Row():
                    with gr.Column(scale=1):
                        dd_hist_col = gr.Dropdown(label="Select Column for Histogram", interactive=True)
                    with gr.Column(scale=2):
                        plot_histogram = gr.Plot()
                with gr.Row():
                    with gr.Column(scale=1):
                        dd_scatter_x = gr.Dropdown(label="X-Axis (Numeric)", interactive=True)
                        dd_scatter_y = gr.Dropdown(label="Y-Axis (Numeric)", interactive=True)
                        dd_scatter_color = gr.Dropdown(label="Color By (Optional)", interactive=True)
                    with gr.Column(scale=2):
                        plot_scatter = gr.Plot()
            with gr.Tab("π Column Deep-Dive"):
                dd_drilldown_col = gr.Dropdown(label="Select Column to Analyze", interactive=True)
                with gr.Row():
                    md_drilldown_stats, plot_drilldown = gr.Markdown(), gr.Plot()
            # Specialised tabs start hidden; run_full_analysis reveals them when relevant.
            with gr.Tab("β Time-Series Analysis", visible=False) as tab_timeseries:
                with gr.Row():
                    dd_ts_date = gr.Dropdown(label="Select Date/Time Column", interactive=True)
                    dd_ts_value = gr.Dropdown(label="Select Value Column", interactive=True)
                plot_ts_decomp, md_ts_stats = gr.Plot(), gr.Markdown()
            with gr.Tab("π Text Analysis", visible=False) as tab_text:
                dd_text_col = gr.Dropdown(label="Select Text Column", interactive=True)
                html_word_cloud = gr.HTML()
            with gr.Tab("π§© Clustering (K-Means)", visible=False) as tab_cluster:
                num_clusters = gr.Slider(minimum=2, maximum=10, value=4, step=1, label="Number of Clusters (K)", interactive=True)
                plot_cluster, md_cluster_summary = gr.Plot(), gr.Markdown()

        # Order must match the list returned by run_full_analysis.
        main_outputs = [
            state_analyzer, ai_report_output,
            profile_missing_df, profile_numeric_df, profile_categorical_df,
            plot_types, plot_missing, plot_correlation,
            dd_hist_col, dd_scatter_x, dd_scatter_y, dd_scatter_color, dd_drilldown_col,
            tab_timeseries, dd_ts_date, dd_ts_value,
            tab_text, dd_text_col,
            tab_cluster, num_clusters]
        analyze_button.click(fn=run_full_analysis, inputs=[upload_button, api_key_input], outputs=main_outputs, show_progress="full")

        dd_hist_col.change(fn=create_histogram, inputs=[state_analyzer, dd_hist_col], outputs=plot_histogram)
        scatter_inputs = [state_analyzer, dd_scatter_x, dd_scatter_y, dd_scatter_color]
        for dd in [dd_scatter_x, dd_scatter_y, dd_scatter_color]:
            dd.change(fn=create_scatterplot, inputs=scatter_inputs, outputs=plot_scatter)
        dd_drilldown_col.change(fn=analyze_single_column, inputs=[state_analyzer, dd_drilldown_col], outputs=[md_drilldown_stats, plot_drilldown])
        ts_inputs = [state_analyzer, dd_ts_date, dd_ts_value]
        for dd in [dd_ts_date, dd_ts_value]:
            dd.change(fn=run_time_series, inputs=ts_inputs, outputs=[plot_ts_decomp, md_ts_stats])
        dd_text_col.change(fn=render_word_cloud, inputs=[state_analyzer, dd_text_col], outputs=html_word_cloud)
        num_clusters.change(fn=run_clustering, inputs=[state_analyzer, num_clusters], outputs=[plot_cluster, md_cluster_summary])
    return demo
# --- Main Application Logic & Orchestration ---
def run_full_analysis(file_obj: gr.File, api_key: str) -> list:
    """Load the uploaded file, analyze it, and build every dashboard output.

    Args:
        file_obj: Value of the ``gr.File`` upload component. Depending on the
            Gradio version/configuration this may be a path string or a
            tempfile-like object exposing ``.name``; both are accepted.
        api_key: Google Gemini API key used for the AI narrative.

    Returns:
        A list of 20 values/component updates, in the same order as the
        ``main_outputs`` list wired up in ``create_ui``.

    Raises:
        gr.Error: if no file or API key was supplied, or if any step fails.
    """
    if file_obj is None:
        raise gr.Error("CRITICAL: No file uploaded.")
    if not api_key:
        raise gr.Error("CRITICAL: Gemini API key is missing.")
    try:
        if isinstance(file_obj, (str, os.PathLike)):
            file_path = os.fspath(file_obj)
        else:
            file_path = file_obj.name
        logging.info(f"Processing uploaded file: {file_path}")
        # Case-insensitive so e.g. "DATA.CSV" is not handed to read_excel.
        if file_path.lower().endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        if len(df) > Config.MAX_UI_ROWS:
            # Reproducible down-sample keeps the UI responsive on large files.
            df = df.sample(n=Config.MAX_UI_ROWS, random_state=42)

        analyzer = DataAnalyzer(df)
        meta = analyzer.metadata
        numeric_cols = meta['numeric_cols']
        has_multiple_numeric = len(numeric_cols) > 1

        ai_context = {'is_timeseries': bool(meta['datetime_cols']), 'has_text': bool(meta['text_cols'])}
        ai_report = analyzer.generate_ai_narrative(api_key, context=ai_context)
        missing_df, num_df, cat_df = analyzer.get_profiling_tables()
        fig_types, fig_missing, fig_corr = analyzer.get_overview_visuals()

        first_numeric = numeric_cols[0] if numeric_cols else None
        second_numeric = numeric_cols[1] if has_multiple_numeric else None
        update_hist_dd = gr.Dropdown(choices=numeric_cols, value=first_numeric)
        update_scatter_x = gr.Dropdown(choices=numeric_cols, value=first_numeric)
        update_scatter_y = gr.Dropdown(choices=numeric_cols, value=second_numeric)
        update_scatter_color = gr.Dropdown(choices=meta['columns'])
        update_drill_dd = gr.Dropdown(choices=meta['columns'])
        # Specialised tabs are shown only when the data supports them.
        show_ts_tab = gr.Tab(visible=bool(meta['datetime_cols']))
        update_ts_date_dd, update_ts_value_dd = gr.Dropdown(choices=meta['datetime_cols']), gr.Dropdown(choices=numeric_cols)
        show_text_tab, update_text_dd = gr.Tab(visible=bool(meta['text_cols'])), gr.Dropdown(choices=meta['text_cols'])
        show_cluster_tab, update_cluster_slider = gr.Tab(visible=has_multiple_numeric), gr.Slider(visible=has_multiple_numeric)

        return [analyzer, ai_report, missing_df, num_df, cat_df, fig_types, fig_missing, fig_corr,
                update_hist_dd, update_scatter_x, update_scatter_y, update_scatter_color, update_drill_dd,
                show_ts_tab, update_ts_date_dd, update_ts_value_dd,
                show_text_tab, update_text_dd,
                show_cluster_tab, update_cluster_slider]
    except Exception as e:
        logging.error(f"A critical error occurred: {e}", exc_info=True)
        raise gr.Error(f"Analysis Failed! Error: {str(e)}") from e
if __name__ == "__main__":
    # Script entry point: build the interface, then serve it on all network
    # interfaces (0.0.0.0) with Gradio's debug mode switched on.
    app_instance = create_ui()
    app_instance.launch(server_name="0.0.0.0", debug=True)